1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/fsext"
 17	"github.com/charmbracelet/crush/internal/history"
 18
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/crush/internal/permission"
 21)
 22
// EditParams holds the arguments of an edit tool call; Run decodes it from
// the call's JSON input.
type EditParams struct {
	// FilePath is the file to operate on. Run joins a relative path onto the
	// tool's working directory.
	FilePath   string `json:"file_path"`
	// OldString is the text to find. When empty, Run treats the call as a
	// request to create a new file.
	OldString  string `json:"old_string"`
	// NewString is the replacement text. When empty (and OldString is not),
	// Run treats the call as a request to delete OldString from the file.
	NewString  string `json:"new_string"`
	// ReplaceAll replaces every occurrence of OldString; otherwise OldString
	// must occur exactly once in the file.
	ReplaceAll bool   `json:"replace_all,omitempty"`
}
 29
// EditPermissionsParams is attached as Params to the permission request made
// before every write, carrying the file and its content before and after the
// proposed change.
type EditPermissionsParams struct {
	FilePath   string `json:"file_path"`
	// OldContent is empty when the request is for creating a new file.
	OldContent string `json:"old_content,omitempty"`
	NewContent string `json:"new_content,omitempty"`
}
 35
// EditResponseMetadata is attached to successful edit responses via
// WithResponseMetadata.
type EditResponseMetadata struct {
	// Additions and Removals are the counts returned by diff.GenerateDiff for
	// the old and new content.
	Additions  int    `json:"additions"`
	Removals   int    `json:"removals"`
	OldContent string `json:"old_content,omitempty"`
	NewContent string `json:"new_content,omitempty"`
}
 42
// editTool implements the edit tool, which creates a file, deletes text from
// a file, or replaces text in a file.
type editTool struct {
	// lspClients are notified after a successful edit and consulted for the
	// edited file's diagnostics.
	lspClients  *csync.Map[string, *lsp.Client]
	// permissions must approve every write before it is performed.
	permissions permission.Service
	// files stores per-session file history.
	files       history.Service
	// workingDir is joined onto relative file paths and trimmed from paths
	// passed to diff generation.
	workingDir  string
}
 49
// EditToolName is the tool's registered name; it is also used as ToolName in
// the permission requests this tool makes.
const EditToolName = "edit"

// editDescription is the tool description returned by Info, embedded at build
// time from edit.md.
//
//go:embed edit.md
var editDescription []byte
 54
 55func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 56	return &editTool{
 57		lspClients:  lspClients,
 58		permissions: permissions,
 59		files:       files,
 60		workingDir:  workingDir,
 61	}
 62}
 63
 64func (e *editTool) Name() string {
 65	return EditToolName
 66}
 67
 68func (e *editTool) Info() ToolInfo {
 69	return ToolInfo{
 70		Name:        EditToolName,
 71		Description: string(editDescription),
 72		Parameters: map[string]any{
 73			"file_path": map[string]any{
 74				"type":        "string",
 75				"description": "The absolute path to the file to modify",
 76			},
 77			"old_string": map[string]any{
 78				"type":        "string",
 79				"description": "The text to replace",
 80			},
 81			"new_string": map[string]any{
 82				"type":        "string",
 83				"description": "The text to replace it with",
 84			},
 85			"replace_all": map[string]any{
 86				"type":        "boolean",
 87				"description": "Replace all occurrences of old_string (default false)",
 88			},
 89		},
 90		Required: []string{"file_path", "old_string", "new_string"},
 91	}
 92}
 93
 94func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 95	var params EditParams
 96	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 97		return NewTextErrorResponse("invalid parameters"), nil
 98	}
 99
100	if params.FilePath == "" {
101		return NewTextErrorResponse("file_path is required"), nil
102	}
103
104	if !filepath.IsAbs(params.FilePath) {
105		params.FilePath = filepath.Join(e.workingDir, params.FilePath)
106	}
107
108	var response ToolResponse
109	var err error
110
111	if params.OldString == "" {
112		response, err = e.createNewFile(ctx, params.FilePath, params.NewString, call)
113		if err != nil {
114			return response, err
115		}
116	}
117
118	if params.NewString == "" {
119		response, err = e.deleteContent(ctx, params.FilePath, params.OldString, params.ReplaceAll, call)
120		if err != nil {
121			return response, err
122		}
123	}
124
125	response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString, params.ReplaceAll, call)
126	if err != nil {
127		return response, err
128	}
129	if response.IsError {
130		// Return early if there was an error during content replacement
131		// This prevents unnecessary LSP diagnostics processing
132		return response, nil
133	}
134
135	notifyLSPs(ctx, e.lspClients, params.FilePath)
136
137	text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
138	text += getDiagnostics(params.FilePath, e.lspClients)
139	response.Content = text
140	return response, nil
141}
142
143func (e *editTool) createNewFile(ctx context.Context, filePath, content string, call ToolCall) (ToolResponse, error) {
144	fileInfo, err := os.Stat(filePath)
145	if err == nil {
146		if fileInfo.IsDir() {
147			return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
148		}
149		return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", filePath)), nil
150	} else if !os.IsNotExist(err) {
151		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
152	}
153
154	dir := filepath.Dir(filePath)
155	if err = os.MkdirAll(dir, 0o755); err != nil {
156		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
157	}
158
159	sessionID, messageID := GetContextValues(ctx)
160	if sessionID == "" || messageID == "" {
161		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
162	}
163
164	_, additions, removals := diff.GenerateDiff(
165		"",
166		content,
167		strings.TrimPrefix(filePath, e.workingDir),
168	)
169	p := e.permissions.Request(
170		permission.CreatePermissionRequest{
171			SessionID:   sessionID,
172			Path:        fsext.PathOrPrefix(filePath, e.workingDir),
173			ToolCallID:  call.ID,
174			ToolName:    EditToolName,
175			Action:      "write",
176			Description: fmt.Sprintf("Create file %s", filePath),
177			Params: EditPermissionsParams{
178				FilePath:   filePath,
179				OldContent: "",
180				NewContent: content,
181			},
182		},
183	)
184	if !p {
185		return ToolResponse{}, permission.ErrorPermissionDenied
186	}
187
188	err = os.WriteFile(filePath, []byte(content), 0o644)
189	if err != nil {
190		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
191	}
192
193	// File can't be in the history so we create a new file history
194	_, err = e.files.Create(ctx, sessionID, filePath, "")
195	if err != nil {
196		// Log error but don't fail the operation
197		return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
198	}
199
200	// Add the new content to the file history
201	_, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
202	if err != nil {
203		// Log error but don't fail the operation
204		slog.Debug("Error creating file history version", "error", err)
205	}
206
207	recordFileWrite(filePath)
208	recordFileRead(filePath)
209
210	return WithResponseMetadata(
211		NewTextResponse("File created: "+filePath),
212		EditResponseMetadata{
213			OldContent: "",
214			NewContent: content,
215			Additions:  additions,
216			Removals:   removals,
217		},
218	), nil
219}
220
221func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
222	fileInfo, err := os.Stat(filePath)
223	if err != nil {
224		if os.IsNotExist(err) {
225			return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
226		}
227		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
228	}
229
230	if fileInfo.IsDir() {
231		return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
232	}
233
234	if getLastReadTime(filePath).IsZero() {
235		return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
236	}
237
238	modTime := fileInfo.ModTime()
239	lastRead := getLastReadTime(filePath)
240	if modTime.After(lastRead) {
241		return NewTextErrorResponse(
242			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
243				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
244			)), nil
245	}
246
247	content, err := os.ReadFile(filePath)
248	if err != nil {
249		return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
250	}
251
252	oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
253
254	var newContent string
255	var deletionCount int
256
257	if replaceAll {
258		newContent = strings.ReplaceAll(oldContent, oldString, "")
259		deletionCount = strings.Count(oldContent, oldString)
260		if deletionCount == 0 {
261			return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
262		}
263	} else {
264		index := strings.Index(oldContent, oldString)
265		if index == -1 {
266			return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
267		}
268
269		lastIndex := strings.LastIndex(oldContent, oldString)
270		if index != lastIndex {
271			return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
272		}
273
274		newContent = oldContent[:index] + oldContent[index+len(oldString):]
275		deletionCount = 1
276	}
277
278	sessionID, messageID := GetContextValues(ctx)
279
280	if sessionID == "" || messageID == "" {
281		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
282	}
283
284	_, additions, removals := diff.GenerateDiff(
285		oldContent,
286		newContent,
287		strings.TrimPrefix(filePath, e.workingDir),
288	)
289
290	p := e.permissions.Request(
291		permission.CreatePermissionRequest{
292			SessionID:   sessionID,
293			Path:        fsext.PathOrPrefix(filePath, e.workingDir),
294			ToolCallID:  call.ID,
295			ToolName:    EditToolName,
296			Action:      "write",
297			Description: fmt.Sprintf("Delete content from file %s", filePath),
298			Params: EditPermissionsParams{
299				FilePath:   filePath,
300				OldContent: oldContent,
301				NewContent: newContent,
302			},
303		},
304	)
305	if !p {
306		return ToolResponse{}, permission.ErrorPermissionDenied
307	}
308
309	if isCrlf {
310		newContent, _ = fsext.ToWindowsLineEndings(newContent)
311	}
312
313	err = os.WriteFile(filePath, []byte(newContent), 0o644)
314	if err != nil {
315		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
316	}
317
318	// Check if file exists in history
319	file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
320	if err != nil {
321		_, err = e.files.Create(ctx, sessionID, filePath, oldContent)
322		if err != nil {
323			// Log error but don't fail the operation
324			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
325		}
326	}
327	if file.Content != oldContent {
328		// User Manually changed the content store an intermediate version
329		_, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
330		if err != nil {
331			slog.Debug("Error creating file history version", "error", err)
332		}
333	}
334	// Store the new version
335	_, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
336	if err != nil {
337		slog.Debug("Error creating file history version", "error", err)
338	}
339
340	recordFileWrite(filePath)
341	recordFileRead(filePath)
342
343	return WithResponseMetadata(
344		NewTextResponse("Content deleted from file: "+filePath),
345		EditResponseMetadata{
346			OldContent: oldContent,
347			NewContent: newContent,
348			Additions:  additions,
349			Removals:   removals,
350		},
351	), nil
352}
353
354func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
355	fileInfo, err := os.Stat(filePath)
356	if err != nil {
357		if os.IsNotExist(err) {
358			return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
359		}
360		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
361	}
362
363	if fileInfo.IsDir() {
364		return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
365	}
366
367	if getLastReadTime(filePath).IsZero() {
368		return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
369	}
370
371	modTime := fileInfo.ModTime()
372	lastRead := getLastReadTime(filePath)
373	if modTime.After(lastRead) {
374		return NewTextErrorResponse(
375			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
376				filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
377			)), nil
378	}
379
380	content, err := os.ReadFile(filePath)
381	if err != nil {
382		return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
383	}
384
385	oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
386
387	var newContent string
388	var replacementCount int
389
390	if replaceAll {
391		newContent = strings.ReplaceAll(oldContent, oldString, newString)
392		replacementCount = strings.Count(oldContent, oldString)
393		if replacementCount == 0 {
394			return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
395		}
396	} else {
397		index := strings.Index(oldContent, oldString)
398		if index == -1 {
399			return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
400		}
401
402		lastIndex := strings.LastIndex(oldContent, oldString)
403		if index != lastIndex {
404			return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
405		}
406
407		newContent = oldContent[:index] + newString + oldContent[index+len(oldString):]
408		replacementCount = 1
409	}
410
411	if oldContent == newContent {
412		return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
413	}
414	sessionID, messageID := GetContextValues(ctx)
415
416	if sessionID == "" || messageID == "" {
417		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
418	}
419	_, additions, removals := diff.GenerateDiff(
420		oldContent,
421		newContent,
422		strings.TrimPrefix(filePath, e.workingDir),
423	)
424
425	p := e.permissions.Request(
426		permission.CreatePermissionRequest{
427			SessionID:   sessionID,
428			Path:        fsext.PathOrPrefix(filePath, e.workingDir),
429			ToolCallID:  call.ID,
430			ToolName:    EditToolName,
431			Action:      "write",
432			Description: fmt.Sprintf("Replace content in file %s", filePath),
433			Params: EditPermissionsParams{
434				FilePath:   filePath,
435				OldContent: oldContent,
436				NewContent: newContent,
437			},
438		},
439	)
440	if !p {
441		return ToolResponse{}, permission.ErrorPermissionDenied
442	}
443
444	if isCrlf {
445		newContent, _ = fsext.ToWindowsLineEndings(newContent)
446	}
447
448	err = os.WriteFile(filePath, []byte(newContent), 0o644)
449	if err != nil {
450		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
451	}
452
453	// Check if file exists in history
454	file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
455	if err != nil {
456		_, err = e.files.Create(ctx, sessionID, filePath, oldContent)
457		if err != nil {
458			// Log error but don't fail the operation
459			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
460		}
461	}
462	if file.Content != oldContent {
463		// User Manually changed the content store an intermediate version
464		_, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
465		if err != nil {
466			slog.Debug("Error creating file history version", "error", err)
467		}
468	}
469	// Store the new version
470	_, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
471	if err != nil {
472		slog.Debug("Error creating file history version", "error", err)
473	}
474
475	recordFileWrite(filePath)
476	recordFileRead(filePath)
477
478	return WithResponseMetadata(
479		NewTextResponse("Content replaced in file: "+filePath),
480		EditResponseMetadata{
481			OldContent: oldContent,
482			NewContent: newContent,
483			Additions:  additions,
484			Removals:   removals,
485		}), nil
486}