1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/diff"
 16	"github.com/charmbracelet/crush/internal/fsext"
 17	"github.com/charmbracelet/crush/internal/history"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/proto"
 21)
 22
 23type (
 24	MultiEditOperation = proto.MultiEditOperation
 25	MultiEditParams    = proto.MultiEditParams
 26
 27	MultiEditPermissionsParams = proto.MultiEditPermissionsParams
 28	MultiEditResponseMetadata  = proto.MultiEditResponseMetadata
 29)
 30
 31type multiEditTool struct {
 32	lspClients  *csync.Map[string, *lsp.Client]
 33	permissions permission.Service
 34	files       history.Service
 35	workingDir  string
 36}
 37
 38const MultiEditToolName = proto.MultiEditToolName
 39
 40//go:embed multiedit.md
 41var multieditDescription []byte
 42
 43func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) BaseTool {
 44	return &multiEditTool{
 45		lspClients:  lspClients,
 46		permissions: permissions,
 47		files:       files,
 48		workingDir:  workingDir,
 49	}
 50}
 51
 52func (m *multiEditTool) Name() string {
 53	return MultiEditToolName
 54}
 55
 56func (m *multiEditTool) Info() ToolInfo {
 57	return ToolInfo{
 58		Name:        MultiEditToolName,
 59		Description: string(multieditDescription),
 60		Parameters: map[string]any{
 61			"file_path": map[string]any{
 62				"type":        "string",
 63				"description": "The absolute path to the file to modify",
 64			},
 65			"edits": map[string]any{
 66				"type": "array",
 67				"items": map[string]any{
 68					"type": "object",
 69					"properties": map[string]any{
 70						"old_string": map[string]any{
 71							"type":        "string",
 72							"description": "The text to replace",
 73						},
 74						"new_string": map[string]any{
 75							"type":        "string",
 76							"description": "The text to replace it with",
 77						},
 78						"replace_all": map[string]any{
 79							"type":        "boolean",
 80							"default":     false,
 81							"description": "Replace all occurrences of old_string (default false).",
 82						},
 83					},
 84					"required":             []string{"old_string", "new_string"},
 85					"additionalProperties": false,
 86				},
 87				"minItems":    1,
 88				"description": "Array of edit operations to perform sequentially on the file",
 89			},
 90		},
 91		Required: []string{"file_path", "edits"},
 92	}
 93}
 94
 95func (m *multiEditTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 96	var params MultiEditParams
 97	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 98		return NewTextErrorResponse("invalid parameters"), nil
 99	}
100
101	if params.FilePath == "" {
102		return NewTextErrorResponse("file_path is required"), nil
103	}
104
105	if len(params.Edits) == 0 {
106		return NewTextErrorResponse("at least one edit operation is required"), nil
107	}
108
109	if !filepath.IsAbs(params.FilePath) {
110		params.FilePath = filepath.Join(m.workingDir, params.FilePath)
111	}
112
113	// Validate all edits before applying any
114	if err := m.validateEdits(params.Edits); err != nil {
115		return NewTextErrorResponse(err.Error()), nil
116	}
117
118	var response ToolResponse
119	var err error
120
121	// Handle file creation case (first edit has empty old_string)
122	if len(params.Edits) > 0 && params.Edits[0].OldString == "" {
123		response, err = m.processMultiEditWithCreation(ctx, params, call)
124	} else {
125		response, err = m.processMultiEditExistingFile(ctx, params, call)
126	}
127
128	if err != nil {
129		return response, err
130	}
131
132	if response.IsError {
133		return response, nil
134	}
135
136	// Notify LSP clients about the change
137	notifyLSPs(ctx, m.lspClients, params.FilePath)
138
139	// Wait for LSP diagnostics and add them to the response
140	text := fmt.Sprintf("<result>\n%s\n</result>\n", response.Content)
141	text += getDiagnostics(params.FilePath, m.lspClients)
142	response.Content = text
143	return response, nil
144}
145
146func (m *multiEditTool) validateEdits(edits []MultiEditOperation) error {
147	for i, edit := range edits {
148		if edit.OldString == edit.NewString {
149			return fmt.Errorf("edit %d: old_string and new_string are identical", i+1)
150		}
151		// Only the first edit can have empty old_string (for file creation)
152		if i > 0 && edit.OldString == "" {
153			return fmt.Errorf("edit %d: only the first edit can have empty old_string (for file creation)", i+1)
154		}
155	}
156	return nil
157}
158
159func (m *multiEditTool) processMultiEditWithCreation(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) {
160	// First edit creates the file
161	firstEdit := params.Edits[0]
162	if firstEdit.OldString != "" {
163		return NewTextErrorResponse("first edit must have empty old_string for file creation"), nil
164	}
165
166	// Check if file already exists
167	if _, err := os.Stat(params.FilePath); err == nil {
168		return NewTextErrorResponse(fmt.Sprintf("file already exists: %s", params.FilePath)), nil
169	} else if !os.IsNotExist(err) {
170		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
171	}
172
173	// Create parent directories
174	dir := filepath.Dir(params.FilePath)
175	if err := os.MkdirAll(dir, 0o755); err != nil {
176		return ToolResponse{}, fmt.Errorf("failed to create parent directories: %w", err)
177	}
178
179	// Start with the content from the first edit
180	currentContent := firstEdit.NewString
181
182	// Apply remaining edits to the content
183	for i := 1; i < len(params.Edits); i++ {
184		edit := params.Edits[i]
185		newContent, err := m.applyEditToContent(currentContent, edit)
186		if err != nil {
187			return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil
188		}
189		currentContent = newContent
190	}
191
192	// Get session and message IDs
193	sessionID, messageID := GetContextValues(ctx)
194	if sessionID == "" || messageID == "" {
195		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
196	}
197
198	// Check permissions
199	_, additions, removals := diff.GenerateDiff("", currentContent, strings.TrimPrefix(params.FilePath, m.workingDir))
200
201	p := m.permissions.Request(permission.CreatePermissionRequest{
202		SessionID:   sessionID,
203		Path:        fsext.PathOrPrefix(params.FilePath, m.workingDir),
204		ToolCallID:  call.ID,
205		ToolName:    MultiEditToolName,
206		Action:      "write",
207		Description: fmt.Sprintf("Create file %s with %d edits", params.FilePath, len(params.Edits)),
208		Params: MultiEditPermissionsParams{
209			FilePath:   params.FilePath,
210			OldContent: "",
211			NewContent: currentContent,
212		},
213	})
214	if !p {
215		return ToolResponse{}, permission.ErrorPermissionDenied
216	}
217
218	// Write the file
219	err := os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
220	if err != nil {
221		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
222	}
223
224	// Update file history
225	_, err = m.files.Create(ctx, sessionID, params.FilePath, "")
226	if err != nil {
227		return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
228	}
229
230	_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent)
231	if err != nil {
232		slog.Debug("Error creating file history version", "error", err)
233	}
234
235	recordFileWrite(params.FilePath)
236	recordFileRead(params.FilePath)
237
238	return WithResponseMetadata(
239		NewTextResponse(fmt.Sprintf("File created with %d edits: %s", len(params.Edits), params.FilePath)),
240		MultiEditResponseMetadata{
241			OldContent:   "",
242			NewContent:   currentContent,
243			Additions:    additions,
244			Removals:     removals,
245			EditsApplied: len(params.Edits),
246		},
247	), nil
248}
249
250func (m *multiEditTool) processMultiEditExistingFile(ctx context.Context, params MultiEditParams, call ToolCall) (ToolResponse, error) {
251	// Validate file exists and is readable
252	fileInfo, err := os.Stat(params.FilePath)
253	if err != nil {
254		if os.IsNotExist(err) {
255			return NewTextErrorResponse(fmt.Sprintf("file not found: %s", params.FilePath)), nil
256		}
257		return ToolResponse{}, fmt.Errorf("failed to access file: %w", err)
258	}
259
260	if fileInfo.IsDir() {
261		return NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil
262	}
263
264	// Check if file was read before editing
265	if getLastReadTime(params.FilePath).IsZero() {
266		return NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
267	}
268
269	// Check if file was modified since last read
270	modTime := fileInfo.ModTime()
271	lastRead := getLastReadTime(params.FilePath)
272	if modTime.After(lastRead) {
273		return NewTextErrorResponse(
274			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
275				params.FilePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339),
276			)), nil
277	}
278
279	// Read current file content
280	content, err := os.ReadFile(params.FilePath)
281	if err != nil {
282		return ToolResponse{}, fmt.Errorf("failed to read file: %w", err)
283	}
284
285	oldContent, isCrlf := fsext.ToUnixLineEndings(string(content))
286	currentContent := oldContent
287
288	// Apply all edits sequentially
289	for i, edit := range params.Edits {
290		newContent, err := m.applyEditToContent(currentContent, edit)
291		if err != nil {
292			return NewTextErrorResponse(fmt.Sprintf("edit %d failed: %s", i+1, err.Error())), nil
293		}
294		currentContent = newContent
295	}
296
297	// Check if content actually changed
298	if oldContent == currentContent {
299		return NewTextErrorResponse("no changes made - all edits resulted in identical content"), nil
300	}
301
302	// Get session and message IDs
303	sessionID, messageID := GetContextValues(ctx)
304	if sessionID == "" || messageID == "" {
305		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for editing file")
306	}
307
308	// Generate diff and check permissions
309	_, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, m.workingDir))
310	p := m.permissions.Request(permission.CreatePermissionRequest{
311		SessionID:   sessionID,
312		Path:        fsext.PathOrPrefix(params.FilePath, m.workingDir),
313		ToolCallID:  call.ID,
314		ToolName:    MultiEditToolName,
315		Action:      "write",
316		Description: fmt.Sprintf("Apply %d edits to file %s", len(params.Edits), params.FilePath),
317		Params: MultiEditPermissionsParams{
318			FilePath:   params.FilePath,
319			OldContent: oldContent,
320			NewContent: currentContent,
321		},
322	})
323	if !p {
324		return ToolResponse{}, permission.ErrorPermissionDenied
325	}
326
327	if isCrlf {
328		currentContent, _ = fsext.ToWindowsLineEndings(currentContent)
329	}
330
331	// Write the updated content
332	err = os.WriteFile(params.FilePath, []byte(currentContent), 0o644)
333	if err != nil {
334		return ToolResponse{}, fmt.Errorf("failed to write file: %w", err)
335	}
336
337	// Update file history
338	file, err := m.files.GetByPathAndSession(ctx, params.FilePath, sessionID)
339	if err != nil {
340		_, err = m.files.Create(ctx, sessionID, params.FilePath, oldContent)
341		if err != nil {
342			return ToolResponse{}, fmt.Errorf("error creating file history: %w", err)
343		}
344	}
345	if file.Content != oldContent {
346		// User manually changed the content, store an intermediate version
347		_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, oldContent)
348		if err != nil {
349			slog.Debug("Error creating file history version", "error", err)
350		}
351	}
352
353	// Store the new version
354	_, err = m.files.CreateVersion(ctx, sessionID, params.FilePath, currentContent)
355	if err != nil {
356		slog.Debug("Error creating file history version", "error", err)
357	}
358
359	recordFileWrite(params.FilePath)
360	recordFileRead(params.FilePath)
361
362	return WithResponseMetadata(
363		NewTextResponse(fmt.Sprintf("Applied %d edits to file: %s", len(params.Edits), params.FilePath)),
364		MultiEditResponseMetadata{
365			OldContent:   oldContent,
366			NewContent:   currentContent,
367			Additions:    additions,
368			Removals:     removals,
369			EditsApplied: len(params.Edits),
370		},
371	), nil
372}
373
374func (m *multiEditTool) applyEditToContent(content string, edit MultiEditOperation) (string, error) {
375	if edit.OldString == "" && edit.NewString == "" {
376		return content, nil
377	}
378
379	if edit.OldString == "" {
380		return "", fmt.Errorf("old_string cannot be empty for content replacement")
381	}
382
383	var newContent string
384	var replacementCount int
385
386	if edit.ReplaceAll {
387		newContent = strings.ReplaceAll(content, edit.OldString, edit.NewString)
388		replacementCount = strings.Count(content, edit.OldString)
389		if replacementCount == 0 {
390			return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks")
391		}
392	} else {
393		index := strings.Index(content, edit.OldString)
394		if index == -1 {
395			return "", fmt.Errorf("old_string not found in content. Make sure it matches exactly, including whitespace and line breaks")
396		}
397
398		lastIndex := strings.LastIndex(content, edit.OldString)
399		if index != lastIndex {
400			return "", fmt.Errorf("old_string appears multiple times in the content. Please provide more context to ensure a unique match, or set replace_all to true")
401		}
402
403		newContent = content[:index] + edit.NewString + content[index+len(edit.OldString):]
404		replacementCount = 1
405	}
406
407	return newContent, nil
408}