1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/diff"
16 "github.com/charmbracelet/crush/internal/fsext"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/proto"
19
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/permission"
22)
23
24type (
25 EditParams = proto.EditParams
26 EditPermissionsParams = proto.EditPermissionsParams
27 EditResponseMetadata = proto.EditResponseMetadata
28)
29
// editTool implements the edit tool: it creates new files and deletes or
// replaces text inside existing ones.
type editTool struct {
	// lspClients is handed to notifyLSPs and getDiagnostics by Run.
	lspClients *csync.Map[string, *lsp.Client]
	// permissions is asked for approval before a file is written.
	permissions permission.Service
	// files records file versions for the current session.
	files history.Service
	// workingDir is the directory relative file paths are resolved against.
	workingDir string
}
36
// EditToolName is the name reported by the edit tool (see Name and Info); it
// mirrors proto.EditToolName.
const EditToolName = proto.EditToolName

// editDescription holds the contents of edit.md, embedded at build time. Info
// returns it as the tool description.
//
//go:embed edit.md
var editDescription []byte
41
42func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
43 return &editTool{
44 lspClients: lspClients,
45 permissions: permissions,
46 files: files,
47 workingDir: workingDir,
48 }
49}
50
// Name returns the edit tool's name, EditToolName.
func (e *editTool) Name() string {
	return EditToolName
}
54
55func (e *editTool) Info() ToolInfo {
56 return ToolInfo{
57 Name: EditToolName,
58 Description: string(editDescription),
59 Parameters: map[string]any{
60 "file_path": map[string]any{
61 "type": "string",
62 "description": "The absolute path to the file to modify",
63 },
64 "old_string": map[string]any{
65 "type": "string",
66 "description": "The text to replace",
67 },
68 "new_string": map[string]any{
69 "type": "string",
70 "description": "The text to replace it with",
71 },
72 "replace_all": map[string]any{
73 "type": "boolean",
74 "description": "Replace all occurrences of old_string (default false)",
75 },
76 },
77 Required: []string{"file_path", "old_string", "new_string"},
78 }
79}
80
81func (e *editTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
82 var params EditParams
83 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
84 return NewTextErrorResponse("invalid parameters"), nil
85 }
86
87 if params.FilePath == "" {
88 return NewTextErrorResponse("file_path is required"), nil
89 }
90
91 if !filepath.IsAbs(params.FilePath) {
92 params.FilePath = filepath.Join(e.workingDir, params.FilePath)
93 }
94
95 var response ToolResponse
96 var err error
97
98 if params.OldString == "" {
99 response, err = e.createNewFile(ctx, params.FilePath, params.NewString, call)
100 if err != nil {
101 return response, err
102 }
103 }
104
105 if params.NewString == "" {
106 response, err = e.deleteContent(ctx, params.FilePath, params.OldString, params.ReplaceAll, call)
107 if err != nil {
108 return response, err
109 }
110 }
111
112 response, err = e.replaceContent(ctx, params.FilePath, params.OldString, params.NewString, params.ReplaceAll, call)
113 if err != nil {
114 return response, err
115 }
116 if response.IsError {
117 // Return early if there was an error during content replacement
118 // This prevents unnecessary LSP diagnostics processing
119 return response, nil
120 }
121
122 notifyLSPs(ctx, e.lspClients, params.FilePath)
123
124 text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
125 text += getDiagnostics(params.FilePath, e.lspClients)
126 response.Content = text
127 return response, nil
128}
129
130func (e *editTool) createNewFile(ctx context.Context, filePath, content string, call ToolCall) (ToolResponse, error) {
131 fileInfo, err := os.Stat(filePath)
132 if err == nil {
133 if fileInfo.IsDir() {
134 return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
135 }
136 return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", filePath)), nil
137 } else if !os.IsNotExist(err) {
138 return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
139 }
140
141 dir := filepath.Dir(filePath)
142 if err = os.MkdirAll(dir, 0o755); err != nil {
143 return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
144 }
145
146 sessionID, messageID := GetContextValues(ctx)
147 if sessionID == "" || messageID == "" {
148 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
149 }
150
151 _, additions, removals := diff.GenerateDiff(
152 "",
153 content,
154 strings.TrimPrefix(filePath, e.workingDir),
155 )
156 p := e.permissions.Request(
157 permission.CreatePermissionRequest{
158 SessionID: sessionID,
159 Path: fsext.PathOrPrefix(filePath, e.workingDir),
160 ToolCallID: call.ID,
161 ToolName: EditToolName,
162 Action: "write",
163 Description: fmt.Sprintf("Create file %s", filePath),
164 Params: EditPermissionsParams{
165 FilePath: filePath,
166 OldContent: "",
167 NewContent: content,
168 },
169 },
170 )
171 if !p {
172 return ToolResponse{}, permission.ErrorPermissionDenied
173 }
174
175 err = os.WriteFile(filePath, []byte(content), 0o644)
176 if err != nil {
177 return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
178 }
179
180 // File can't be in the history so we create a new file history
181 _, err = e.files.Create(ctx, sessionID, filePath, "")
182 if err != nil {
183 // Log error but don't fail the operation
184 return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
185 }
186
187 // Add the new content to the file history
188 _, err = e.files.CreateVersion(ctx, sessionID, filePath, content)
189 if err != nil {
190 // Log error but don't fail the operation
191 slog.Debug("Error creating file history version", "error", err)
192 }
193
194 recordFileWrite(filePath)
195 recordFileRead(filePath)
196
197 return WithResponseMetadata(
198 NewTextResponse("File created: "+filePath),
199 EditResponseMetadata{
200 OldContent: "",
201 NewContent: content,
202 Additions: additions,
203 Removals: removals,
204 },
205 ), nil
206}
207
208func (e *editTool) deleteContent(ctx context.Context, filePath, oldString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
209 fileInfo, err := os.Stat(filePath)
210 if err != nil {
211 if os.IsNotExist(err) {
212 return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
213 }
214 return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
215 }
216
217 if fileInfo.IsDir() {
218 return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
219 }
220
221 if getLastReadTime(filePath).IsZero() {
222 return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
223 }
224
225 modTime := fileInfo.ModTime()
226 lastRead := getLastReadTime(filePath)
227 if modTime.After(lastRead) {
228 return NewTextErrorResponse(
229 fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
230 filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
231 )), nil
232 }
233
234 content, err := os.ReadFile(filePath)
235 if err != nil {
236 return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
237 }
238
239 oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
240
241 var newContent string
242 var deletionCount int
243
244 if replaceAll {
245 newContent = strings.ReplaceAll(oldContent, oldString, "")
246 deletionCount = strings.Count(oldContent, oldString)
247 if deletionCount == 0 {
248 return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
249 }
250 } else {
251 index := strings.Index(oldContent, oldString)
252 if index == -1 {
253 return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
254 }
255
256 lastIndex := strings.LastIndex(oldContent, oldString)
257 if index != lastIndex {
258 return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
259 }
260
261 newContent = oldContent[:index] + oldContent[index+len(oldString):]
262 deletionCount = 1
263 }
264
265 sessionID, messageID := GetContextValues(ctx)
266
267 if sessionID == "" || messageID == "" {
268 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
269 }
270
271 _, additions, removals := diff.GenerateDiff(
272 oldContent,
273 newContent,
274 strings.TrimPrefix(filePath, e.workingDir),
275 )
276
277 p := e.permissions.Request(
278 permission.CreatePermissionRequest{
279 SessionID: sessionID,
280 Path: fsext.PathOrPrefix(filePath, e.workingDir),
281 ToolCallID: call.ID,
282 ToolName: EditToolName,
283 Action: "write",
284 Description: fmt.Sprintf("Delete content from file %s", filePath),
285 Params: EditPermissionsParams{
286 FilePath: filePath,
287 OldContent: oldContent,
288 NewContent: newContent,
289 },
290 },
291 )
292 if !p {
293 return ToolResponse{}, permission.ErrorPermissionDenied
294 }
295
296 if isCrlf {
297 newContent, _ = fsext.ToWindowsLineEndings(newContent)
298 }
299
300 err = os.WriteFile(filePath, []byte(newContent), 0o644)
301 if err != nil {
302 return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
303 }
304
305 // Check if file exists in history
306 file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
307 if err != nil {
308 _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
309 if err != nil {
310 // Log error but don't fail the operation
311 return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
312 }
313 }
314 if file.Content != oldContent {
315 // User Manually changed the content store an intermediate version
316 _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
317 if err != nil {
318 slog.Debug("Error creating file history version", "error", err)
319 }
320 }
321 // Store the new version
322 _, err = e.files.CreateVersion(ctx, sessionID, filePath, "")
323 if err != nil {
324 slog.Debug("Error creating file history version", "error", err)
325 }
326
327 recordFileWrite(filePath)
328 recordFileRead(filePath)
329
330 return WithResponseMetadata(
331 NewTextResponse("Content deleted from file: "+filePath),
332 EditResponseMetadata{
333 OldContent: oldContent,
334 NewContent: newContent,
335 Additions: additions,
336 Removals: removals,
337 },
338 ), nil
339}
340
341func (e *editTool) replaceContent(ctx context.Context, filePath, oldString, newString string, replaceAll bool, call ToolCall) (ToolResponse, error) {
342 fileInfo, err := os.Stat(filePath)
343 if err != nil {
344 if os.IsNotExist(err) {
345 return NewTextErrorResponse(fmt.Sprintf("file not found: %s", filePath)), nil
346 }
347 return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
348 }
349
350 if fileInfo.IsDir() {
351 return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
352 }
353
354 if getLastReadTime(filePath).IsZero() {
355 return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
356 }
357
358 modTime := fileInfo.ModTime()
359 lastRead := getLastReadTime(filePath)
360 if modTime.After(lastRead) {
361 return NewTextErrorResponse(
362 fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
363 filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
364 )), nil
365 }
366
367 content, err := os.ReadFile(filePath)
368 if err != nil {
369 return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
370 }
371
372 oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
373
374 var newContent string
375 var replacementCount int
376
377 if replaceAll {
378 newContent = strings.ReplaceAll(oldContent, oldString, newString)
379 replacementCount = strings.Count(oldContent, oldString)
380 if replacementCount == 0 {
381 return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
382 }
383 } else {
384 index := strings.Index(oldContent, oldString)
385 if index == -1 {
386 return NewTextErrorResponse("old_string not found in file. Make sure it matches exactly, including whitespace and line breaks"), nil
387 }
388
389 lastIndex := strings.LastIndex(oldContent, oldString)
390 if index != lastIndex {
391 return NewTextErrorResponse("old_string appears multiple times in the file. Please provide more context to ensure a unique match, or set replace_all to true"), nil
392 }
393
394 newContent = oldContent[:index] + newString + oldContent[index+len(oldString):]
395 replacementCount = 1
396 }
397
398 if oldContent == newContent {
399 return NewTextErrorResponse("new content is the same as old content. No changes made."), nil
400 }
401 sessionID, messageID := GetContextValues(ctx)
402
403 if sessionID == "" || messageID == "" {
404 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
405 }
406 _, additions, removals := diff.GenerateDiff(
407 oldContent,
408 newContent,
409 strings.TrimPrefix(filePath, e.workingDir),
410 )
411
412 p := e.permissions.Request(
413 permission.CreatePermissionRequest{
414 SessionID: sessionID,
415 Path: fsext.PathOrPrefix(filePath, e.workingDir),
416 ToolCallID: call.ID,
417 ToolName: EditToolName,
418 Action: "write",
419 Description: fmt.Sprintf("Replace content in file %s", filePath),
420 Params: EditPermissionsParams{
421 FilePath: filePath,
422 OldContent: oldContent,
423 NewContent: newContent,
424 },
425 },
426 )
427 if !p {
428 return ToolResponse{}, permission.ErrorPermissionDenied
429 }
430
431 if isCrlf {
432 newContent, _ = fsext.ToWindowsLineEndings(newContent)
433 }
434
435 err = os.WriteFile(filePath, []byte(newContent), 0o644)
436 if err != nil {
437 return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
438 }
439
440 // Check if file exists in history
441 file, err := e.files.GetByPathAndSession(ctx, filePath, sessionID)
442 if err != nil {
443 _, err = e.files.Create(ctx, sessionID, filePath, oldContent)
444 if err != nil {
445 // Log error but don't fail the operation
446 return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
447 }
448 }
449 if file.Content != oldContent {
450 // User Manually changed the content store an intermediate version
451 _, err = e.files.CreateVersion(ctx, sessionID, filePath, oldContent)
452 if err != nil {
453 slog.Debug("Error creating file history version", "error", err)
454 }
455 }
456 // Store the new version
457 _, err = e.files.CreateVersion(ctx, sessionID, filePath, newContent)
458 if err != nil {
459 slog.Debug("Error creating file history version", "error", err)
460 }
461
462 recordFileWrite(filePath)
463 recordFileRead(filePath)
464
465 return WithResponseMetadata(
466 NewTextResponse("Content replaced in file: "+filePath),
467 EditResponseMetadata{
468 OldContent: oldContent,
469 NewContent: newContent,
470 Additions: additions,
471 Removals: removals,
472 }), nil
473}