fix: mark files that are attched as read (#1777)

Kujtim Hoxha created

Change summary

internal/agent/tools/edit.go                  | 21 +++---
internal/agent/tools/file.go                  | 53 ---------------
internal/agent/tools/multiedit.go             | 13 ++-
internal/agent/tools/multiedit_test.go        |  3 
internal/agent/tools/view.go                  |  3 
internal/agent/tools/write.go                 |  7 +
internal/filetracker/filetracker.go           | 70 +++++++++++++++++++++
internal/tui/components/chat/editor/editor.go | 10 +++
8 files changed, 106 insertions(+), 74 deletions(-)

Detailed changes

internal/agent/tools/edit.go 🔗

@@ -14,6 +14,7 @@ import (
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/diff"
 	"github.com/charmbracelet/crush/internal/filepathext"
+	"github.com/charmbracelet/crush/internal/filetracker"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/history"
 
@@ -159,8 +160,8 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool
 		slog.Error("Error creating file history version", "error", err)
 	}
 
-	recordFileWrite(filePath)
-	recordFileRead(filePath)
+	filetracker.RecordWrite(filePath)
+	filetracker.RecordRead(filePath)
 
 	return fantasy.WithResponseMetadata(
 		fantasy.NewTextResponse("File created: "+filePath),
@@ -186,12 +187,12 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool
 		return fantasy.NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
 	}
 
-	if getLastReadTime(filePath).IsZero() {
+	if filetracker.LastReadTime(filePath).IsZero() {
 		return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
 	}
 
 	modTime := fileInfo.ModTime()
-	lastRead := getLastReadTime(filePath)
+	lastRead := filetracker.LastReadTime(filePath)
 	if modTime.After(lastRead) {
 		return fantasy.NewTextErrorResponse(
 			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
@@ -292,8 +293,8 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool
 		slog.Error("Error creating file history version", "error", err)
 	}
 
-	recordFileWrite(filePath)
-	recordFileRead(filePath)
+	filetracker.RecordWrite(filePath)
+	filetracker.RecordRead(filePath)
 
 	return fantasy.WithResponseMetadata(
 		fantasy.NewTextResponse("Content deleted from file: "+filePath),
@@ -319,12 +320,12 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep
 		return fantasy.NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
 	}
 
-	if getLastReadTime(filePath).IsZero() {
+	if filetracker.LastReadTime(filePath).IsZero() {
 		return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
 	}
 
 	modTime := fileInfo.ModTime()
-	lastRead := getLastReadTime(filePath)
+	lastRead := filetracker.LastReadTime(filePath)
 	if modTime.After(lastRead) {
 		return fantasy.NewTextErrorResponse(
 			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
@@ -427,8 +428,8 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep
 		slog.Error("Error creating file history version", "error", err)
 	}
 
-	recordFileWrite(filePath)
-	recordFileRead(filePath)
+	filetracker.RecordWrite(filePath)
+	filetracker.RecordRead(filePath)
 
 	return fantasy.WithResponseMetadata(
 		fantasy.NewTextResponse("Content replaced in file: "+filePath),

internal/agent/tools/file.go 🔗

@@ -1,53 +0,0 @@
-package tools
-
-import (
-	"sync"
-	"time"
-)
-
-// File record to track when files were read/written
-type fileRecord struct {
-	path      string
-	readTime  time.Time
-	writeTime time.Time
-}
-
-var (
-	fileRecords     = make(map[string]fileRecord)
-	fileRecordMutex sync.RWMutex
-)
-
-func recordFileRead(path string) {
-	fileRecordMutex.Lock()
-	defer fileRecordMutex.Unlock()
-
-	record, exists := fileRecords[path]
-	if !exists {
-		record = fileRecord{path: path}
-	}
-	record.readTime = time.Now()
-	fileRecords[path] = record
-}
-
-func getLastReadTime(path string) time.Time {
-	fileRecordMutex.RLock()
-	defer fileRecordMutex.RUnlock()
-
-	record, exists := fileRecords[path]
-	if !exists {
-		return time.Time{}
-	}
-	return record.readTime
-}
-
-func recordFileWrite(path string) {
-	fileRecordMutex.Lock()
-	defer fileRecordMutex.Unlock()
-
-	record, exists := fileRecords[path]
-	if !exists {
-		record = fileRecord{path: path}
-	}
-	record.writeTime = time.Now()
-	fileRecords[path] = record
-}

internal/agent/tools/multiedit.go 🔗

@@ -14,6 +14,7 @@ import (
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/diff"
 	"github.com/charmbracelet/crush/internal/filepathext"
+	"github.com/charmbracelet/crush/internal/filetracker"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/history"
 	"github.com/charmbracelet/crush/internal/lsp"
@@ -206,8 +207,8 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call
 		slog.Error("Error creating file history version", "error", err)
 	}
 
-	recordFileWrite(params.FilePath)
-	recordFileRead(params.FilePath)
+	filetracker.RecordWrite(params.FilePath)
+	filetracker.RecordRead(params.FilePath)
 
 	var message string
 	if len(failedEdits) > 0 {
@@ -244,13 +245,13 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call
 	}
 
 	// Check if file was read before editing
-	if getLastReadTime(params.FilePath).IsZero() {
+	if filetracker.LastReadTime(params.FilePath).IsZero() {
 		return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
 	}
 
 	// Check if file was modified since last read
 	modTime := fileInfo.ModTime()
-	lastRead := getLastReadTime(params.FilePath)
+	lastRead := filetracker.LastReadTime(params.FilePath)
 	if modTime.After(lastRead) {
 		return fantasy.NewTextErrorResponse(
 			fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
@@ -362,8 +363,8 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call
 		slog.Error("Error creating file history version", "error", err)
 	}
 
-	recordFileWrite(params.FilePath)
-	recordFileRead(params.FilePath)
+	filetracker.RecordWrite(params.FilePath)
+	filetracker.RecordRead(params.FilePath)
 
 	var message string
 	if len(failedEdits) > 0 {

internal/agent/tools/multiedit_test.go 🔗

@@ -7,6 +7,7 @@ import (
 	"testing"
 
 	"github.com/charmbracelet/crush/internal/csync"
+	"github.com/charmbracelet/crush/internal/filetracker"
 	"github.com/charmbracelet/crush/internal/history"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/permission"
@@ -119,7 +120,7 @@ func TestMultiEditSequentialApplication(t *testing.T) {
 	_ = NewMultiEditTool(lspClients, permissions, files, tmpDir)
 
 	// Simulate reading the file first.
-	recordFileRead(testFile)
+	filetracker.RecordRead(testFile)
 
 	// Manually test the sequential application logic.
 	currentContent := content

internal/agent/tools/view.go 🔗

@@ -15,6 +15,7 @@ import (
 	"charm.land/fantasy"
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/filepathext"
+	"github.com/charmbracelet/crush/internal/filetracker"
 	"github.com/charmbracelet/crush/internal/lsp"
 	"github.com/charmbracelet/crush/internal/permission"
 )
@@ -194,7 +195,7 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss
 			}
 			output += "\n</file>\n"
 			output += getDiagnostics(filePath, lspClients)
-			recordFileRead(filePath)
+			filetracker.RecordRead(filePath)
 			return fantasy.WithResponseMetadata(
 				fantasy.NewTextResponse(output),
 				ViewResponseMetadata{

internal/agent/tools/write.go 🔗

@@ -14,6 +14,7 @@ import (
 	"github.com/charmbracelet/crush/internal/csync"
 	"github.com/charmbracelet/crush/internal/diff"
 	"github.com/charmbracelet/crush/internal/filepathext"
+	"github.com/charmbracelet/crush/internal/filetracker"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/history"
 
@@ -72,7 +73,7 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis
 				}
 
 				modTime := fileInfo.ModTime()
-				lastRead := getLastReadTime(filePath)
+				lastRead := filetracker.LastReadTime(filePath)
 				if modTime.After(lastRead) {
 					return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
 						filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
@@ -156,8 +157,8 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis
 				slog.Error("Error creating file history version", "error", err)
 			}
 
-			recordFileWrite(filePath)
-			recordFileRead(filePath)
+			filetracker.RecordWrite(filePath)
+			filetracker.RecordRead(filePath)
 
 			notifyLSPs(ctx, lspClients, params.FilePath)
 

internal/filetracker/filetracker.go 🔗

@@ -0,0 +1,70 @@
+// Package filetracker tracks file read/write times to prevent editing files
+// that haven't been read, and to detect external modifications.
+//
+// TODO: Consider moving this to persistent storage (e.g., the database) to
+// preserve file access history across sessions.
+// We would need to make sure to handle the case where we reload a session and the underlying files did change.
+package filetracker
+
+import (
+	"sync"
+	"time"
+)
+
+// record tracks when a file was read/written.
+type record struct {
+	path      string
+	readTime  time.Time
+	writeTime time.Time
+}
+
+var (
+	records     = make(map[string]record)
+	recordMutex sync.RWMutex
+)
+
+// RecordRead records when a file was read.
+func RecordRead(path string) {
+	recordMutex.Lock()
+	defer recordMutex.Unlock()
+
+	rec, exists := records[path]
+	if !exists {
+		rec = record{path: path}
+	}
+	rec.readTime = time.Now()
+	records[path] = rec
+}
+
+// LastReadTime returns when a file was last read. Returns zero time if never
+// read.
+func LastReadTime(path string) time.Time {
+	recordMutex.RLock()
+	defer recordMutex.RUnlock()
+
+	rec, exists := records[path]
+	if !exists {
+		return time.Time{}
+	}
+	return rec.readTime
+}
+
+// RecordWrite records when a file was written.
+func RecordWrite(path string) {
+	recordMutex.Lock()
+	defer recordMutex.Unlock()
+
+	rec, exists := records[path]
+	if !exists {
+		rec = record{path: path}
+	}
+	rec.writeTime = time.Now()
+	records[path] = rec
+}
+
+// Reset clears all file tracking records. Useful for testing.
+func Reset() {
+	recordMutex.Lock()
+	defer recordMutex.Unlock()
+	records = make(map[string]record)
+}

internal/tui/components/chat/editor/editor.go 🔗

@@ -18,6 +18,7 @@ import (
 	tea "charm.land/bubbletea/v2"
 	"charm.land/lipgloss/v2"
 	"github.com/charmbracelet/crush/internal/app"
+	"github.com/charmbracelet/crush/internal/filetracker"
 	"github.com/charmbracelet/crush/internal/fsext"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/session"
@@ -202,11 +203,20 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
 				m.currentQuery = ""
 				m.completionsStartIndex = 0
 			}
+			absPath, _ := filepath.Abs(item.Path)
+			// Skip attachment if file was already read and hasn't been modified.
+			lastRead := filetracker.LastReadTime(absPath)
+			if !lastRead.IsZero() {
+				if info, err := os.Stat(item.Path); err == nil && !info.ModTime().After(lastRead) {
+					return m, nil
+				}
+			}
 			content, err := os.ReadFile(item.Path)
 			if err != nil {
 				// if it fails, let the LLM handle it later.
 				return m, nil
 			}
+			filetracker.RecordRead(absPath)
 			m.attachments = append(m.attachments, message.Attachment{
 				FilePath: item.Path,
 				FileName: filepath.Base(item.Path),