1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/fsext"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/permission"
 20)
 21
 22type MultiEditOperation struct {
 23	OldString  string `json:"old_string"`
 24	NewString  string `json:"new_string"`
 25	ReplaceAll bool   `json:"replace_all,omitempty"`
 26}
 27
 28type MultiEditParams struct {
 29	FilePath string               `json:"file_path"`
 30	Edits    []MultiEditOperation `json:"edits"`
 31}
 32
 33type MultiEditPermissionsParams struct {
 34	FilePath   string `json:"file_path"`
 35	OldContent string `json:"old_content,omitempty"`
 36	NewContent string `json:"new_content,omitempty"`
 37}
 38
 39type MultiEditResponseMetadata struct {
 40	Additions    int    `json:"additions"`
 41	Removals     int    `json:"removals"`
 42	OldContent   string `json:"old_content,omitempty"`
 43	NewContent   string `json:"new_content,omitempty"`
 44	EditsApplied int    `json:"edits_applied"`
 45}
 46
 47type multiEditTool struct {
 48	lspClients  *csync.Map[string, *lsp.Client]
 49	permissions permission.Service
 50	files       history.Service
 51	workingDir  string
 52}
 53
// MultiEditToolName is the identifier of the multiedit tool. Name returns it
// and Info uses it as ToolInfo.Name.
const MultiEditToolName = "multiedit"

// multieditDescription holds the contents of multiedit.md, embedded at build
// time. Info converts it to a string and uses it as the tool description.
//
//go:embed multiedit.md
var multieditDescription []byte
 58
 59func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 60	return &multiEditTool{
 61		lspClients:  lspClients,
 62		permissions: permissions,
 63		files:       files,
 64		workingDir:  workingDir,
 65	}
 66}
 67
 68func (m *multiEditTool) Name() string {
 69	return MultiEditToolName
 70}
 71
 72func (m *multiEditTool) Info() ToolInfo {
 73	return ToolInfo{
 74		Name:        MultiEditToolName,
 75		Description: string(multieditDescription),
 76		Parameters: map[string]any{
 77			"file_path": map[string]any{
 78				"type":        "string",
 79				"description": "The absolute path to the file to modify",
 80			},
 81			"edits": map[string]any{
 82				"type": "array",
 83				"items": map[string]any{
 84					"type": "object",
 85					"properties": map[string]any{
 86						"old_string": map[string]any{
 87							"type":        "string",
 88							"description": "The text to replace",
 89						},
 90						"new_string": map[string]any{
 91							"type":        "string",
 92							"description": "The text to replace it with",
 93						},
 94						"replace_all": map[string]any{
 95							"type":        "boolean",
 96							"default":     false,
 97							"description": "Replace all occurrences of old_string (default false).",
 98						},
 99					},
100					"required":             []string{"old_string", "new_string"},
101					"additionalProperties": false,
102				},
103				"minItems":    1,
104				"description": "Array of edit operations to perform sequentially on the file",
105			},
106		},
107		Required: []string{"file_path", "edits"},
108	}
109}
110
111func (m *multiEditTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
112	var params MultiEditParams
113	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
114		return NewTextErrorResponse("invalid parameters"), nil
115	}
116
117	if params.FilePath == "" {
118		return NewTextErrorResponse("file_path is required"), nil
119	}
120
121	if len(params.Edits) == 0 {
122		return NewTextErrorResponse("at least one edit operation is required"), nil
123	}
124
125	if !filepath.IsAbs(params.FilePath) {
126		params.FilePath = filepath.Join(m.workingDir, params.FilePath)
127	}
128
129	// Validate all edits before applying any
130	if err := m.validateEdits(params.Edits); err != nil {
131		return NewTextErrorResponse(err.Error()), nil
132	}
133
134	var response ToolResponse
135	var err error
136
137	// Handle file creation case (first edit has empty old_string)
138	if len(params.Edits) > 0 && params.Edits[0].OldString == "" {
139		response, err = m.processMultiEditWithCreation(ctx, params, call)
140	} else {
141		response, err = m.processMultiEditExistingFile(ctx, params, call)
142	}
143
144	if err != nil {
145		return response, err
146	}
147
148	if response.IsError {
149		return response, nil
150	}
151
152	// Notify LSP clients about the change
153	notifyLSPs(ctx, m.lspClients, params.FilePath)
154
155	// Wait for LSP diagnostics and add them to the response
156	text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
157	text += getDiagnostics(params.FilePath, m.lspClients)
158	response.Content = text
159	return response, nil
160}
161
162func (m *multiEditTool) validateEdits(edits []MultiEditOperation) error {
163	for i, edit := range edits {
164		if edit.OldString == edit.NewString {
165			return fmt.Errorf("edit %d: old_string and new_string are identical", i+1)
166		}
167		// Only the first edit can have empty old_string (for file creation)
168		if i > 0 && edit.OldString == "" {
169			return fmt.Errorf("edit %d: only the first edit can have empty old_string (for file creation)", i+1)
170		}
171	}
172	return nil
173}
174
175func (m *multiEditTool) processMultiEditWithCreation(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) {
176	// First edit creates the file
177	firstEdit := params.Edits[0]
178	if firstEdit.OldString != "" {
179		return NewTextErrorResponse("first edit must have empty old_string for file creation"), nil
180	}
181
182	// Check if file already exists
183	if _, err := os.Stat(params.FilePath); err == nil {
184		return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", params.FilePath)), nil
185	} else if !os.IsNotExist(err) {
186		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
187	}
188
189	// Create parent directories
190	dir := filepath.Dir(params.FilePath)
191	if err := os.MkdirAll(dir, 0o755); err != nil {
192		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
193	}
194
195	// Start with the content from the first edit
196	currentContent := firstEdit.NewString
197
198	// Apply remaining edits to the content
199	for i := 1; i < len(params.Edits); i++ {
200		edit := params.Edits[i]
201		newContent, err := m.applyEditToContent(currentContent, edit)
202		if err != nil {
203			return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil
204		}
205		currentContent = newContent
206	}
207
208	// Get session and message IDs
209	sessionID, messageID := GetContextValues(ctx)
210	if sessionID == "" || messageID == "" {
211		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
212	}
213
214	// Check permissions
215	_, additions, removals := diff.GenerateDiff("", currentContent, strings.TrimPrefix(params.FilePath, m.workingDir))
216
217	p := m.permissions.Request(permission.CreatePermissionRequest{
218		SessionID:   sessionID,
219		Path:        fsext.PathOrPrefix(params.FilePath, m.workingDir),
220		ToolCallID:  call.ID,
221		ToolName:    MultiEditToolName,
222		Action:      "write",
223		Description: fmt.Sprintf("Create file %s with %d edits", params.FilePath, len(params.Edits)),
224		Params: MultiEditPermissionsParams{
225			FilePath:   params.FilePath,
226			OldContent: "",
227			NewContent: currentContent,
228		},
229	})
230	if !p {
231		return ToolResponse{}, permission.ErrorPermissionDenied
232	}
233
234	// Write the file
235	err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
236	if err != nil {
237		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
238	}
239
240	// Update file history
241	_, err = m.files.Create(ctx, sessionID, params.FilePath, "")
242	if err != nil {
243		return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
244	}
245
246	_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent)
247	if err != nil {
248		slog.Debug("Error creating file history version", "error", err)
249	}
250
251	recordFileWrite(params.FilePath)
252	recordFileRead(params.FilePath)
253
254	return WithResponseMetadata(
255		NewTextResponse(fmt.Sprintf("File created with %d edits: %s", len(params.Edits), params.FilePath)),
256		MultiEditResponseMetadata{
257			OldContent:   "",
258			NewContent:   currentContent,
259			Additions:    additions,
260			Removals:     removals,
261			EditsApplied: len(params.Edits),
262		},
263	), nil
264}
265
266func (m *multiEditTool) processMultiEditExistingFile(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) {
267	// Validate file exists and is readable
268	fileInfo, err := os.Stat(params.FilePath)
269	if err != nil {
270		if os.IsNotExist(err) {
271			return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil
272		}
273		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
274	}
275
276	if fileInfo.IsDir() {
277		return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil
278	}
279
280	// Check if file was read before editing
281	if getLastReadTime(params.FilePath).IsZero() {
282		return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
283	}
284
285	// Check if file was modified since last read
286	modTime := fileInfo.ModTime()
287	lastRead := getLastReadTime(params.FilePath)
288	if modTime.After(lastRead) {
289		return NewTextErrorResponse(
290			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
291				params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
292			)), nil
293	}
294
295	// Read current file content
296	content, err := os.ReadFile(params.FilePath)
297	if err != nil {
298		return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
299	}
300
301	oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
302	currentContent := oldContent
303
304	// Apply all edits sequentially
305	for i, edit := range params.Edits {
306		newContent, err := m.applyEditToContent(currentContent, edit)
307		if err != nil {
308			return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil
309		}
310		currentContent = newContent
311	}
312
313	// Check if content actually changed
314	if oldContent == currentContent {
315		return NewTextErrorResponse("no changes made - all edits resulted in identical content"), nil
316	}
317
318	// Get session and message IDs
319	sessionID, messageID := GetContextValues(ctx)
320	if sessionID == "" || messageID == "" {
321		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for editing file")
322	}
323
324	// Generate diff and check permissions
325	_, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, m.workingDir))
326	p := m.permissions.Request(permission.CreatePermissionRequest{
327		SessionID:   sessionID,
328		Path:        fsext.PathOrPrefix(params.FilePath, m.workingDir),
329		ToolCallID:  call.ID,
330		ToolName:    MultiEditToolName,
331		Action:      "write",
332		Description: fmt.Sprintf("Apply %d edits to file %s", len(params.Edits), params.FilePath),
333		Params: MultiEditPermissionsParams{
334			FilePath:   params.FilePath,
335			OldContent: oldContent,
336			NewContent: currentContent,
337		},
338	})
339	if !p {
340		return ToolResponse{}, permission.ErrorPermissionDenied
341	}
342
343	if isCrlf {
344		currentContent, _ = fsext.ToWindowsLineEndings(currentContent)
345	}
346
347	// Write the updated content
348	err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
349	if err != nil {
350		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
351	}
352
353	// Update file history
354	file, err := m.files.GetByPathAndSession(ctx, params.FilePath, sessionID)
355	if err != nil {
356		_, err = m.files.Create(ctx, sessionID, params.FilePath, oldContent)
357		if err != nil {
358			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
359		}
360	}
361	if file.Content != oldContent {
362		// User manually changed the content, store an intermediate version
363		_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent)
364		if err != nil {
365			slog.Debug("Error creating file history version", "error", err)
366		}
367	}
368
369	// Store the new version
370	_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent)
371	if err != nil {
372		slog.Debug("Error creating file history version", "error", err)
373	}
374
375	recordFileWrite(params.FilePath)
376	recordFileRead(params.FilePath)
377
378	return WithResponseMetadata(
379		NewTextResponse(fmt.Sprintf("Applied %d edits to file: %s", len(params.Edits), params.FilePath)),
380		MultiEditResponseMetadata{
381			OldContent:   oldContent,
382			NewContent:   currentContent,
383			Additions:    additions,
384			Removals:     removals,
385			EditsApplied: len(params.Edits),
386		},
387	), nil
388}
389
390func (m *multiEditTool) applyEditToContent(content string, edit MultiEditOperation) (string, error) {
391	if edit.OldString == "" && edit.NewString == "" {
392		return content, nil
393	}
394
395	if edit.OldString == "" {
396		return "", fmt.Errorf("old_string cannot be empty for content replacement")
397	}
398
399	var newContent string
400	var replacementCount int
401
402	if edit.ReplaceAll {
403		newContent = strings.ReplaceAll(content, edit.OldString, edit.NewString)
404		replacementCount = strings.Count(content, edit.OldString)
405		if replacementCount == 0 {
406			return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks")
407		}
408	} else {
409		index := strings.Index(content, edit.OldString)
410		if index == -1 {
411			return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks")
412		}
413
414		lastIndex := strings.LastIndex(content, edit.OldString)
415		if index != lastIndex {
416			return "", fmt.Errorf("old_string appears multiple times in the content. Please provide more context to ensure a unique match, or set replace_all to true")
417		}
418
419		newContent = content[:index] + edit.NewString + content[index+len(edit.OldString):]
420		replacementCount = 1
421	}
422
423	return newContent, nil
424}