1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "os"
10 "path/filepath"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/diff"
16 "github.com/charmbracelet/crush/internal/fsext"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/permission"
20 "github.com/charmbracelet/crush/internal/proto"
21)
22
23type (
24 MultiEditOperation = proto.MultiEditOperation
25 MultiEditParams = proto.MultiEditParams
26
27 MultiEditPermissionsParams = proto.MultiEditPermissionsParams
28 MultiEditResponseMetadata = proto.MultiEditResponseMetadata
29)
30
31type multiEditTool struct {
32 lspClients *csync.Map[string, *lsp.Client]
33 permissions permission.Service
34 files history.Service
35 workingDir string
36}
37
38const MultiEditToolName = proto.MultiEditToolName
39
40//go:embed multiedit.md
41var multieditDescription []byte
42
43func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
44 return &multiEditTool{
45 lspClients: lspClients,
46 permissions: permissions,
47 files: files,
48 workingDir: workingDir,
49 }
50}
51
52func (m *multiEditTool) Name() string {
53 return MultiEditToolName
54}
55
56func (m *multiEditTool) Info() ToolInfo {
57 return ToolInfo{
58 Name: MultiEditToolName,
59 Description: string(multieditDescription),
60 Parameters: map[string]any{
61 "file_path": map[string]any{
62 "type": "string",
63 "description": "The absolute path to the file to modify",
64 },
65 "edits": map[string]any{
66 "type": "array",
67 "items": map[string]any{
68 "type": "object",
69 "properties": map[string]any{
70 "old_string": map[string]any{
71 "type": "string",
72 "description": "The text to replace",
73 },
74 "new_string": map[string]any{
75 "type": "string",
76 "description": "The text to replace it with",
77 },
78 "replace_all": map[string]any{
79 "type": "boolean",
80 "default": false,
81 "description": "Replace all occurrences of old_string (default false).",
82 },
83 },
84 "required": []string{"old_string", "new_string"},
85 "additionalProperties": false,
86 },
87 "minItems": 1,
88 "description": "Array of edit operations to perform sequentially on the file",
89 },
90 },
91 Required: []string{"file_path", "edits"},
92 }
93}
94
95func (m *multiEditTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
96 var params MultiEditParams
97 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
98 return NewTextErrorResponse("invalid parameters"), nil
99 }
100
101 if params.FilePath == "" {
102 return NewTextErrorResponse("file_path is required"), nil
103 }
104
105 if len(params.Edits) == 0 {
106 return NewTextErrorResponse("at least one edit operation is required"), nil
107 }
108
109 if !filepath.IsAbs(params.FilePath) {
110 params.FilePath = filepath.Join(m.workingDir, params.FilePath)
111 }
112
113 // Validate all edits before applying any
114 if err := m.validateEdits(params.Edits); err != nil {
115 return NewTextErrorResponse(err.Error()), nil
116 }
117
118 var response ToolResponse
119 var err error
120
121 // Handle file creation case (first edit has empty old_string)
122 if len(params.Edits) > 0 && params.Edits[0].OldString == "" {
123 response, err = m.processMultiEditWithCreation(ctx, params, call)
124 } else {
125 response, err = m.processMultiEditExistingFile(ctx, params, call)
126 }
127
128 if err != nil {
129 return response, err
130 }
131
132 if response.IsError {
133 return response, nil
134 }
135
136 // Notify LSP clients about the change
137 notifyLSPs(ctx, m.lspClients, params.FilePath)
138
139 // Wait for LSP diagnostics and add them to the response
140 text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
141 text += getDiagnostics(params.FilePath, m.lspClients)
142 response.Content = text
143 return response, nil
144}
145
146func (m *multiEditTool) validateEdits(edits []MultiEditOperation) error {
147 for i, edit := range edits {
148 if edit.OldString == edit.NewString {
149 return fmt.Errorf("edit %d: old_string and new_string are identical", i+1)
150 }
151 // Only the first edit can have empty old_string (for file creation)
152 if i > 0 && edit.OldString == "" {
153 return fmt.Errorf("edit %d: only the first edit can have empty old_string (for file creation)", i+1)
154 }
155 }
156 return nil
157}
158
159func (m *multiEditTool) processMultiEditWithCreation(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) {
160 // First edit creates the file
161 firstEdit := params.Edits[0]
162 if firstEdit.OldString != "" {
163 return NewTextErrorResponse("first edit must have empty old_string for file creation"), nil
164 }
165
166 // Check if file already exists
167 if _, err := os.Stat(params.FilePath); err == nil {
168 return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", params.FilePath)), nil
169 } else if !os.IsNotExist(err) {
170 return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
171 }
172
173 // Create parent directories
174 dir := filepath.Dir(params.FilePath)
175 if err := os.MkdirAll(dir, 0o755); err != nil {
176 return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
177 }
178
179 // Start with the content from the first edit
180 currentContent := firstEdit.NewString
181
182 // Apply remaining edits to the content
183 for i := 1; i < len(params.Edits); i++ {
184 edit := params.Edits[i]
185 newContent, err := m.applyEditToContent(currentContent, edit)
186 if err != nil {
187 return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil
188 }
189 currentContent = newContent
190 }
191
192 // Get session and message IDs
193 sessionID, messageID := GetContextValues(ctx)
194 if sessionID == "" || messageID == "" {
195 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
196 }
197
198 // Check permissions
199 _, additions, removals := diff.GenerateDiff("", currentContent, strings.TrimPrefix(params.FilePath, m.workingDir))
200
201 p := m.permissions.Request(permission.CreatePermissionRequest{
202 SessionID: sessionID,
203 Path: fsext.PathOrPrefix(params.FilePath, m.workingDir),
204 ToolCallID: call.ID,
205 ToolName: MultiEditToolName,
206 Action: "write",
207 Description: fmt.Sprintf("Create file %s with %d edits", params.FilePath, len(params.Edits)),
208 Params: MultiEditPermissionsParams{
209 FilePath: params.FilePath,
210 OldContent: "",
211 NewContent: currentContent,
212 },
213 })
214 if !p {
215 return ToolResponse{}, permission.ErrorPermissionDenied
216 }
217
218 // Write the file
219 err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
220 if err != nil {
221 return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
222 }
223
224 // Update file history
225 _, err = m.files.Create(ctx, sessionID, params.FilePath, "")
226 if err != nil {
227 return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
228 }
229
230 _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent)
231 if err != nil {
232 slog.Debug("Error creating file history version", "error", err)
233 }
234
235 recordFileWrite(params.FilePath)
236 recordFileRead(params.FilePath)
237
238 return WithResponseMetadata(
239 NewTextResponse(fmt.Sprintf("File created with %d edits: %s", len(params.Edits), params.FilePath)),
240 MultiEditResponseMetadata{
241 OldContent: "",
242 NewContent: currentContent,
243 Additions: additions,
244 Removals: removals,
245 EditsApplied: len(params.Edits),
246 },
247 ), nil
248}
249
250func (m *multiEditTool) processMultiEditExistingFile(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) {
251 // Validate file exists and is readable
252 fileInfo, err := os.Stat(params.FilePath)
253 if err != nil {
254 if os.IsNotExist(err) {
255 return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil
256 }
257 return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
258 }
259
260 if fileInfo.IsDir() {
261 return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil
262 }
263
264 // Check if file was read before editing
265 if getLastReadTime(params.FilePath).IsZero() {
266 return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
267 }
268
269 // Check if file was modified since last read
270 modTime := fileInfo.ModTime()
271 lastRead := getLastReadTime(params.FilePath)
272 if modTime.After(lastRead) {
273 return NewTextErrorResponse(
274 fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
275 params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
276 )), nil
277 }
278
279 // Read current file content
280 content, err := os.ReadFile(params.FilePath)
281 if err != nil {
282 return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
283 }
284
285 oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
286 currentContent := oldContent
287
288 // Apply all edits sequentially
289 for i, edit := range params.Edits {
290 newContent, err := m.applyEditToContent(currentContent, edit)
291 if err != nil {
292 return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil
293 }
294 currentContent = newContent
295 }
296
297 // Check if content actually changed
298 if oldContent == currentContent {
299 return NewTextErrorResponse("no changes made - all edits resulted in identical content"), nil
300 }
301
302 // Get session and message IDs
303 sessionID, messageID := GetContextValues(ctx)
304 if sessionID == "" || messageID == "" {
305 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for editing file")
306 }
307
308 // Generate diff and check permissions
309 _, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, m.workingDir))
310 p := m.permissions.Request(permission.CreatePermissionRequest{
311 SessionID: sessionID,
312 Path: fsext.PathOrPrefix(params.FilePath, m.workingDir),
313 ToolCallID: call.ID,
314 ToolName: MultiEditToolName,
315 Action: "write",
316 Description: fmt.Sprintf("Apply %d edits to file %s", len(params.Edits), params.FilePath),
317 Params: MultiEditPermissionsParams{
318 FilePath: params.FilePath,
319 OldContent: oldContent,
320 NewContent: currentContent,
321 },
322 })
323 if !p {
324 return ToolResponse{}, permission.ErrorPermissionDenied
325 }
326
327 if isCrlf {
328 currentContent, _ = fsext.ToWindowsLineEndings(currentContent)
329 }
330
331 // Write the updated content
332 err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
333 if err != nil {
334 return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
335 }
336
337 // Update file history
338 file, err := m.files.GetByPathAndSession(ctx, params.FilePath, sessionID)
339 if err != nil {
340 _, err = m.files.Create(ctx, sessionID, params.FilePath, oldContent)
341 if err != nil {
342 return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
343 }
344 }
345 if file.Content != oldContent {
346 // User manually changed the content, store an intermediate version
347 _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent)
348 if err != nil {
349 slog.Debug("Error creating file history version", "error", err)
350 }
351 }
352
353 // Store the new version
354 _, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent)
355 if err != nil {
356 slog.Debug("Error creating file history version", "error", err)
357 }
358
359 recordFileWrite(params.FilePath)
360 recordFileRead(params.FilePath)
361
362 return WithResponseMetadata(
363 NewTextResponse(fmt.Sprintf("Applied %d edits to file: %s", len(params.Edits), params.FilePath)),
364 MultiEditResponseMetadata{
365 OldContent: oldContent,
366 NewContent: currentContent,
367 Additions: additions,
368 Removals: removals,
369 EditsApplied: len(params.Edits),
370 },
371 ), nil
372}
373
374func (m *multiEditTool) applyEditToContent(content string, edit MultiEditOperation) (string, error) {
375 if edit.OldString == "" && edit.NewString == "" {
376 return content, nil
377 }
378
379 if edit.OldString == "" {
380 return "", fmt.Errorf("old_string cannot be empty for content replacement")
381 }
382
383 var newContent string
384 var replacementCount int
385
386 if edit.ReplaceAll {
387 newContent = strings.ReplaceAll(content, edit.OldString, edit.NewString)
388 replacementCount = strings.Count(content, edit.OldString)
389 if replacementCount == 0 {
390 return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks")
391 }
392 } else {
393 index := strings.Index(content, edit.OldString)
394 if index == -1 {
395 return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks")
396 }
397
398 lastIndex := strings.LastIndex(content, edit.OldString)
399 if index != lastIndex {
400 return "", fmt.Errorf("old_string appears multiple times in the content. Please provide more context to ensure a unique match, or set replace_all to true")
401 }
402
403 newContent = content[:index] + edit.NewString + content[index+len(edit.OldString):]
404 replacementCount = 1
405 }
406
407 return newContent, nil
408}