Detailed changes
@@ -168,7 +168,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (
tools.NewGlobTool(tmpDir),
tools.NewGrepTool(tmpDir),
tools.NewSourcegraphTool(client),
- tools.NewViewTool(c.lspClients, c.permissions, tmpDir),
+ tools.NewViewTool(c.lspClients, c.permissions, c.filetracker, tmpDir),
}
agent := NewSessionAgent(SessionAgentOptions{
@@ -20,6 +20,7 @@ import (
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/filetracker"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
@@ -37,6 +38,7 @@ type fakeEnv struct {
messages message.Service
permissions permission.Service
history history.Service
+ filetracker *filetracker.Service
lspClients *csync.Map[string, *lsp.Client]
}
@@ -117,6 +119,7 @@ func testEnv(t *testing.T) fakeEnv {
permissions := permission.NewPermissionService(workingDir, true, []string{})
history := history.NewService(q, conn)
+ filetrackerService := filetracker.NewService(q)
lspClients := csync.NewMap[string, *lsp.Client]()
t.Cleanup(func() {
@@ -130,6 +133,7 @@ func testEnv(t *testing.T) fakeEnv {
messages,
permissions,
history,
+ &filetrackerService,
lspClients,
}
}
@@ -200,15 +204,15 @@ func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel
allTools := []fantasy.AgentTool{
tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName),
tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()),
- tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
- tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir),
+ tools.NewEditTool(env.lspClients, env.permissions, env.history, *env.filetracker, env.workingDir),
+ tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, *env.filetracker, env.workingDir),
tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()),
tools.NewGlobTool(env.workingDir),
tools.NewGrepTool(env.workingDir),
tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls),
tools.NewSourcegraphTool(r.GetDefaultClient()),
- tools.NewViewTool(env.lspClients, env.permissions, env.workingDir),
- tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir),
+ tools.NewViewTool(env.lspClients, env.permissions, *env.filetracker, env.workingDir),
+ tools.NewWriteTool(env.lspClients, env.permissions, env.history, *env.filetracker, env.workingDir),
}
return testSessionAgent(env, large, small, systemPrompt, allTools...), nil
@@ -22,6 +22,7 @@ import (
"github.com/charmbracelet/crush/internal/agent/tools"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/filetracker"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/log"
"github.com/charmbracelet/crush/internal/lsp"
@@ -64,6 +65,7 @@ type coordinator struct {
messages message.Service
permissions permission.Service
history history.Service
+ filetracker filetracker.Service
lspClients *csync.Map[string, *lsp.Client]
currentAgent SessionAgent
@@ -79,6 +81,7 @@ func NewCoordinator(
messages message.Service,
permissions permission.Service,
history history.Service,
+ filetracker filetracker.Service,
lspClients *csync.Map[string, *lsp.Client],
) (Coordinator, error) {
c := &coordinator{
@@ -87,6 +90,7 @@ func NewCoordinator(
messages: messages,
permissions: permissions,
history: history,
+ filetracker: filetracker,
lspClients: lspClients,
agents: make(map[string]SessionAgent),
}
@@ -393,16 +397,16 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
tools.NewJobOutputTool(),
tools.NewJobKillTool(),
tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
- tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
- tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
+ tools.NewEditTool(c.lspClients, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
+ tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
tools.NewGlobTool(c.cfg.WorkingDir()),
tools.NewGrepTool(c.cfg.WorkingDir()),
tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
tools.NewSourcegraphTool(nil),
tools.NewTodosTool(c.sessions),
- tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
- tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
+ tools.NewViewTool(c.lspClients, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
+ tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
)
if len(c.cfg.LSP) > 0 {
@@ -56,10 +56,17 @@ type editContext struct {
ctx context.Context
permissions permission.Service
files history.Service
+ filetracker filetracker.Service
workingDir string
}
-func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool {
+func NewEditTool(
+ lspClients *csync.Map[string, *lsp.Client],
+ permissions permission.Service,
+ files history.Service,
+ filetracker filetracker.Service,
+ workingDir string,
+) fantasy.AgentTool {
return fantasy.NewAgentTool(
EditToolName,
string(editDescription),
@@ -73,7 +80,7 @@ func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss
var response fantasy.ToolResponse
var err error
- editCtx := editContext{ctx, permissions, files, workingDir}
+ editCtx := editContext{ctx, permissions, files, filetracker, workingDir}
if params.OldString == "" {
response, err = createNewFile(editCtx, params.FilePath, params.NewString, call)
@@ -168,8 +175,7 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool
slog.Error("Error creating file history version", "error", err)
}
- filetracker.RecordWrite(filePath)
- filetracker.RecordRead(filePath)
+ edit.filetracker.RecordRead(edit.ctx, sessionID, filePath)
return fantasy.WithResponseMetadata(
fantasy.NewTextResponse("File created: "+filePath),
@@ -195,12 +201,17 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool
return fantasy.NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
}
- if filetracker.LastReadTime(filePath).IsZero() {
+ sessionID := GetSessionFromContext(edit.ctx)
+ if sessionID == "" {
+ return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for deleting content")
+ }
+
+ lastRead := edit.filetracker.LastReadTime(edit.ctx, sessionID, filePath)
+ if lastRead.IsZero() {
return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
}
- modTime := fileInfo.ModTime()
- lastRead := filetracker.LastReadTime(filePath)
+ modTime := fileInfo.ModTime().Truncate(time.Second)
if modTime.After(lastRead) {
return fantasy.NewTextErrorResponse(
fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
@@ -236,12 +247,6 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool
newContent = oldContent[:index] + oldContent[index+len(oldString):]
}
- sessionID := GetSessionFromContext(edit.ctx)
-
- if sessionID == "" {
- return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for deleting content")
- }
-
_, additions, removals := diff.GenerateDiff(
oldContent,
newContent,
@@ -301,8 +306,7 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool
slog.Error("Error creating file history version", "error", err)
}
- filetracker.RecordWrite(filePath)
- filetracker.RecordRead(filePath)
+ edit.filetracker.RecordRead(edit.ctx, sessionID, filePath)
return fantasy.WithResponseMetadata(
fantasy.NewTextResponse("Content deleted from file: "+filePath),
@@ -328,12 +332,17 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep
return fantasy.NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil
}
- if filetracker.LastReadTime(filePath).IsZero() {
+ sessionID := GetSessionFromContext(edit.ctx)
+ if sessionID == "" {
+ return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for edit a file")
+ }
+
+ lastRead := edit.filetracker.LastReadTime(edit.ctx, sessionID, filePath)
+ if lastRead.IsZero() {
return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
}
- modTime := fileInfo.ModTime()
- lastRead := filetracker.LastReadTime(filePath)
+ modTime := fileInfo.ModTime().Truncate(time.Second)
if modTime.After(lastRead) {
return fantasy.NewTextErrorResponse(
fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
@@ -369,11 +378,6 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep
if oldContent == newContent {
return fantasy.NewTextErrorResponse("new content is the same as old content. No changes made."), nil
}
- sessionID := GetSessionFromContext(edit.ctx)
-
- if sessionID == "" {
- return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
- }
_, additions, removals := diff.GenerateDiff(
oldContent,
newContent,
@@ -433,8 +437,7 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep
slog.Error("Error creating file history version", "error", err)
}
- filetracker.RecordWrite(filePath)
- filetracker.RecordRead(filePath)
+ edit.filetracker.RecordRead(edit.ctx, sessionID, filePath)
return fantasy.WithResponseMetadata(
fantasy.NewTextResponse("Content replaced in file: "+filePath),
@@ -58,7 +58,13 @@ const MultiEditToolName = "multiedit"
//go:embed multiedit.md
var multieditDescription []byte
-func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool {
+func NewMultiEditTool(
+ lspClients *csync.Map[string, *lsp.Client],
+ permissions permission.Service,
+ files history.Service,
+ filetracker filetracker.Service,
+ workingDir string,
+) fantasy.AgentTool {
return fantasy.NewAgentTool(
MultiEditToolName,
string(multieditDescription),
@@ -81,7 +87,7 @@ func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions pe
var response fantasy.ToolResponse
var err error
- editCtx := editContext{ctx, permissions, files, workingDir}
+ editCtx := editContext{ctx, permissions, files, filetracker, workingDir}
// Handle file creation case (first edit has empty old_string)
if len(params.Edits) > 0 && params.Edits[0].OldString == "" {
response, err = processMultiEditWithCreation(editCtx, params, call)
@@ -210,8 +216,7 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call
slog.Error("Error creating file history version", "error", err)
}
- filetracker.RecordWrite(params.FilePath)
- filetracker.RecordRead(params.FilePath)
+ edit.filetracker.RecordRead(edit.ctx, sessionID, params.FilePath)
var message string
if len(failedEdits) > 0 {
@@ -247,14 +252,19 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call
return fantasy.NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil
}
+ sessionID := GetSessionFromContext(edit.ctx)
+ if sessionID == "" {
+ return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for editing file")
+ }
+
// Check if file was read before editing
- if filetracker.LastReadTime(params.FilePath).IsZero() {
+ lastRead := edit.filetracker.LastReadTime(edit.ctx, sessionID, params.FilePath)
+ if lastRead.IsZero() {
return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil
}
- // Check if file was modified since last read
- modTime := fileInfo.ModTime()
- lastRead := filetracker.LastReadTime(params.FilePath)
+ // Check if file was modified since last read.
+ modTime := fileInfo.ModTime().Truncate(time.Second)
if modTime.After(lastRead) {
return fantasy.NewTextErrorResponse(
fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)",
@@ -301,12 +311,6 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call
return fantasy.NewTextErrorResponse("no changes made - all edits resulted in identical content"), nil
}
- // Get session and message IDs
- sessionID := GetSessionFromContext(edit.ctx)
- if sessionID == "" {
- return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for editing file")
- }
-
// Generate diff and check permissions
_, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, edit.workingDir))
@@ -369,8 +373,7 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call
slog.Error("Error creating file history version", "error", err)
}
- filetracker.RecordWrite(params.FilePath)
- filetracker.RecordRead(params.FilePath)
+ edit.filetracker.RecordRead(edit.ctx, sessionID, params.FilePath)
var message string
if len(failedEdits) > 0 {
@@ -6,10 +6,7 @@ import (
"path/filepath"
"testing"
- "github.com/charmbracelet/crush/internal/csync"
- "github.com/charmbracelet/crush/internal/filetracker"
"github.com/charmbracelet/crush/internal/history"
- "github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/permission"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/stretchr/testify/require"
@@ -111,17 +108,6 @@ func TestMultiEditSequentialApplication(t *testing.T) {
err := os.WriteFile(testFile, []byte(content), 0o644)
require.NoError(t, err)
- // Mock components.
- lspClients := csync.NewMap[string, *lsp.Client]()
- permissions := &mockPermissionService{Broker: pubsub.NewBroker[permission.PermissionRequest]()}
- files := &mockHistoryService{Broker: pubsub.NewBroker[history.File]()}
-
- // Create multiedit tool.
- _ = NewMultiEditTool(lspClients, permissions, files, tmpDir)
-
- // Simulate reading the file first.
- filetracker.RecordRead(testFile)
-
// Manually test the sequential application logic.
currentContent := content
@@ -47,7 +47,13 @@ const (
MaxLineLength = 2000
)
-func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, workingDir string, skillsPaths ...string) fantasy.AgentTool {
+func NewViewTool(
+ lspClients *csync.Map[string, *lsp.Client],
+ permissions permission.Service,
+ filetracker filetracker.Service,
+ workingDir string,
+ skillsPaths ...string,
+) fantasy.AgentTool {
return fantasy.NewAgentTool(
ViewToolName,
string(viewDescription),
@@ -74,13 +80,13 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss
isOutsideWorkDir := err != nil || strings.HasPrefix(relPath, "..")
isSkillFile := isInSkillsPath(absFilePath, skillsPaths)
+ sessionID := GetSessionFromContext(ctx)
+ if sessionID == "" {
+ return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory")
+ }
+
// Request permission for files outside working directory, unless it's a skill file.
if isOutsideWorkDir && !isSkillFile {
- sessionID := GetSessionFromContext(ctx)
- if sessionID == "" {
- return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory")
- }
-
granted, err := permissions.Request(ctx,
permission.CreatePermissionRequest{
SessionID: sessionID,
@@ -190,7 +196,7 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss
}
output += "\n</file>\n"
output += getDiagnostics(filePath, lspClients)
- filetracker.RecordRead(filePath)
+ filetracker.RecordRead(ctx, sessionID, filePath)
return fantasy.WithResponseMetadata(
fantasy.NewTextResponse(output),
ViewResponseMetadata{
@@ -44,7 +44,13 @@ type WriteResponseMetadata struct {
const WriteToolName = "write"
-func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool {
+func NewWriteTool(
+ lspClients *csync.Map[string, *lsp.Client],
+ permissions permission.Service,
+ files history.Service,
+ filetracker filetracker.Service,
+ workingDir string,
+) fantasy.AgentTool {
return fantasy.NewAgentTool(
WriteToolName,
string(writeDescription),
@@ -57,6 +63,11 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis
return fantasy.NewTextErrorResponse("content is required"), nil
}
+ sessionID := GetSessionFromContext(ctx)
+ if sessionID == "" {
+ return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
+ }
+
filePath := filepathext.SmartJoin(workingDir, params.FilePath)
fileInfo, err := os.Stat(filePath)
@@ -65,8 +76,8 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis
return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil
}
- modTime := fileInfo.ModTime()
- lastRead := filetracker.LastReadTime(filePath)
+ modTime := fileInfo.ModTime().Truncate(time.Second)
+ lastRead := filetracker.LastReadTime(ctx, sessionID, filePath)
if modTime.After(lastRead) {
return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.",
filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil
@@ -93,11 +104,6 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis
}
}
- sessionID := GetSessionFromContext(ctx)
- if sessionID == "" {
- return fantasy.ToolResponse{}, fmt.Errorf("session_id is required")
- }
-
diff, additions, removals := diff.GenerateDiff(
oldContent,
params.Content,
@@ -153,8 +159,7 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis
slog.Error("Error creating file history version", "error", err)
}
- filetracker.RecordWrite(filePath)
- filetracker.RecordRead(filePath)
+ filetracker.RecordRead(ctx, sessionID, filePath)
notifyLSPs(ctx, lspClients, params.FilePath)
@@ -23,6 +23,7 @@ import (
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
"github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/filetracker"
"github.com/charmbracelet/crush/internal/format"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/log"
@@ -53,6 +54,7 @@ type App struct {
Messages message.Service
History history.Service
Permissions permission.Service
+ FileTracker filetracker.Service
AgentCoordinator agent.Coordinator
@@ -87,6 +89,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
Messages: messages,
History: files,
Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools),
+ FileTracker: filetracker.NewService(q),
LSPClients: csync.NewMap[string, *lsp.Client](),
globalCtx: ctx,
@@ -460,6 +463,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error {
app.Messages,
app.Permissions,
app.History,
+ app.FileTracker,
app.LSPClients,
)
if err != nil {
@@ -57,6 +57,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.getFileByPathAndSessionStmt, err = db.PrepareContext(ctx, getFileByPathAndSession); err != nil {
return nil, fmt.Errorf("error preparing query GetFileByPathAndSession: %w", err)
}
+ if q.getFileReadStmt, err = db.PrepareContext(ctx, getFileRead); err != nil {
+ return nil, fmt.Errorf("error preparing query GetFileRead: %w", err)
+ }
if q.getHourDayHeatmapStmt, err = db.PrepareContext(ctx, getHourDayHeatmap); err != nil {
return nil, fmt.Errorf("error preparing query GetHourDayHeatmap: %w", err)
}
@@ -111,6 +114,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.listUserMessagesBySessionStmt, err = db.PrepareContext(ctx, listUserMessagesBySession); err != nil {
return nil, fmt.Errorf("error preparing query ListUserMessagesBySession: %w", err)
}
+ if q.recordFileReadStmt, err = db.PrepareContext(ctx, recordFileRead); err != nil {
+ return nil, fmt.Errorf("error preparing query RecordFileRead: %w", err)
+ }
if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil {
return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err)
}
@@ -180,6 +186,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing getFileByPathAndSessionStmt: %w", cerr)
}
}
+ if q.getFileReadStmt != nil {
+ if cerr := q.getFileReadStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing getFileReadStmt: %w", cerr)
+ }
+ }
if q.getHourDayHeatmapStmt != nil {
if cerr := q.getHourDayHeatmapStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing getHourDayHeatmapStmt: %w", cerr)
@@ -270,6 +281,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing listUserMessagesBySessionStmt: %w", cerr)
}
}
+ if q.recordFileReadStmt != nil {
+ if cerr := q.recordFileReadStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing recordFileReadStmt: %w", cerr)
+ }
+ }
if q.updateMessageStmt != nil {
if cerr := q.updateMessageStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing updateMessageStmt: %w", cerr)
@@ -335,6 +351,7 @@ type Queries struct {
getAverageResponseTimeStmt *sql.Stmt
getFileStmt *sql.Stmt
getFileByPathAndSessionStmt *sql.Stmt
+ getFileReadStmt *sql.Stmt
getHourDayHeatmapStmt *sql.Stmt
getMessageStmt *sql.Stmt
getRecentActivityStmt *sql.Stmt
@@ -353,6 +370,7 @@ type Queries struct {
listNewFilesStmt *sql.Stmt
listSessionsStmt *sql.Stmt
listUserMessagesBySessionStmt *sql.Stmt
+ recordFileReadStmt *sql.Stmt
updateMessageStmt *sql.Stmt
updateSessionStmt *sql.Stmt
updateSessionTitleAndUsageStmt *sql.Stmt
@@ -373,6 +391,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
getAverageResponseTimeStmt: q.getAverageResponseTimeStmt,
getFileStmt: q.getFileStmt,
getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
+ getFileReadStmt: q.getFileReadStmt,
getHourDayHeatmapStmt: q.getHourDayHeatmapStmt,
getMessageStmt: q.getMessageStmt,
getRecentActivityStmt: q.getRecentActivityStmt,
@@ -391,6 +410,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
listNewFilesStmt: q.listNewFilesStmt,
listSessionsStmt: q.listSessionsStmt,
listUserMessagesBySessionStmt: q.listUserMessagesBySessionStmt,
+ recordFileReadStmt: q.recordFileReadStmt,
updateMessageStmt: q.updateMessageStmt,
updateSessionStmt: q.updateSessionStmt,
updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt,
@@ -0,0 +1,20 @@
+-- +goose Up
+-- +goose StatementBegin
+CREATE TABLE IF NOT EXISTS read_files (
+ session_id TEXT NOT NULL CHECK (session_id != ''),
+ path TEXT NOT NULL CHECK (path != ''),
+ read_at INTEGER NOT NULL, -- Unix timestamp in seconds when file was last read
+ FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE,
+ PRIMARY KEY (path, session_id)
+);
+
+CREATE INDEX IF NOT EXISTS idx_read_files_session_id ON read_files (session_id);
+CREATE INDEX IF NOT EXISTS idx_read_files_path ON read_files (path);
+-- +goose StatementEnd
+
+-- +goose Down
+-- +goose StatementBegin
+DROP INDEX IF EXISTS idx_read_files_path;
+DROP INDEX IF EXISTS idx_read_files_session_id;
+DROP TABLE IF EXISTS read_files;
+-- +goose StatementEnd
@@ -31,6 +31,12 @@ type Message struct {
IsSummaryMessage int64 `json:"is_summary_message"`
}
+type ReadFile struct {
+ SessionID string `json:"session_id"`
+ Path string `json:"path"`
+ ReadAt int64 `json:"read_at"` // Unix timestamp when file was last read
+}
+
type Session struct {
ID string `json:"id"`
ParentSessionID sql.NullString `json:"parent_session_id"`
@@ -20,6 +20,7 @@ type Querier interface {
GetAverageResponseTime(ctx context.Context) (int64, error)
GetFile(ctx context.Context, id string) (File, error)
GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error)
+ GetFileRead(ctx context.Context, arg GetFileReadParams) (ReadFile, error)
GetHourDayHeatmap(ctx context.Context) ([]GetHourDayHeatmapRow, error)
GetMessage(ctx context.Context, id string) (Message, error)
GetRecentActivity(ctx context.Context) ([]GetRecentActivityRow, error)
@@ -38,6 +39,7 @@ type Querier interface {
ListNewFiles(ctx context.Context) ([]File, error)
ListSessions(ctx context.Context) ([]Session, error)
ListUserMessagesBySession(ctx context.Context, sessionID string) ([]Message, error)
+ RecordFileRead(ctx context.Context, arg RecordFileReadParams) error
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error
@@ -0,0 +1,57 @@
+// Code generated by sqlc. DO NOT EDIT.
+// versions:
+// sqlc v1.30.0
+// source: read_files.sql
+
+package db
+
+import (
+ "context"
+)
+
+const getFileRead = `-- name: GetFileRead :one
+SELECT session_id, path, read_at FROM read_files
+WHERE session_id = ? AND path = ? LIMIT 1
+`
+
+type GetFileReadParams struct {
+ SessionID string `json:"session_id"`
+ Path string `json:"path"`
+}
+
+func (q *Queries) GetFileRead(ctx context.Context, arg GetFileReadParams) (ReadFile, error) {
+ row := q.queryRow(ctx, q.getFileReadStmt, getFileRead, arg.SessionID, arg.Path)
+ var i ReadFile
+ err := row.Scan(
+ &i.SessionID,
+ &i.Path,
+ &i.ReadAt,
+ )
+ return i, err
+}
+
+const recordFileRead = `-- name: RecordFileRead :exec
+INSERT INTO read_files (
+ session_id,
+ path,
+ read_at
+) VALUES (
+ ?,
+ ?,
+ strftime('%s', 'now')
+) ON CONFLICT(path, session_id) DO UPDATE SET
+ read_at = excluded.read_at
+`
+
+type RecordFileReadParams struct {
+ SessionID string `json:"session_id"`
+ Path string `json:"path"`
+}
+
+func (q *Queries) RecordFileRead(ctx context.Context, arg RecordFileReadParams) error {
+ _, err := q.exec(ctx, q.recordFileReadStmt, recordFileRead,
+ arg.SessionID,
+ arg.Path,
+ )
+ return err
+}
@@ -0,0 +1,15 @@
+-- name: RecordFileRead :exec
+INSERT INTO read_files (
+ session_id,
+ path,
+ read_at
+) VALUES (
+ ?,
+ ?,
+ strftime('%s', 'now')
+) ON CONFLICT(path, session_id) DO UPDATE SET
+ read_at = excluded.read_at;
+
+-- name: GetFileRead :one
+SELECT * FROM read_files
+WHERE session_id = ? AND path = ? LIMIT 1;
@@ -1,70 +0,0 @@
-// Package filetracker tracks file read/write times to prevent editing files
-// that haven't been read, and to detect external modifications.
-//
-// TODO: Consider moving this to persistent storage (e.g., the database) to
-// preserve file access history across sessions.
-// We would need to make sure to handle the case where we reload a session and the underlying files did change.
-package filetracker
-
-import (
- "sync"
- "time"
-)
-
-// record tracks when a file was read/written.
-type record struct {
- path string
- readTime time.Time
- writeTime time.Time
-}
-
-var (
- records = make(map[string]record)
- recordMutex sync.RWMutex
-)
-
-// RecordRead records when a file was read.
-func RecordRead(path string) {
- recordMutex.Lock()
- defer recordMutex.Unlock()
-
- rec, exists := records[path]
- if !exists {
- rec = record{path: path}
- }
- rec.readTime = time.Now()
- records[path] = rec
-}
-
-// LastReadTime returns when a file was last read. Returns zero time if never
-// read.
-func LastReadTime(path string) time.Time {
- recordMutex.RLock()
- defer recordMutex.RUnlock()
-
- rec, exists := records[path]
- if !exists {
- return time.Time{}
- }
- return rec.readTime
-}
-
-// RecordWrite records when a file was written.
-func RecordWrite(path string) {
- recordMutex.Lock()
- defer recordMutex.Unlock()
-
- rec, exists := records[path]
- if !exists {
- rec = record{path: path}
- }
- rec.writeTime = time.Now()
- records[path] = rec
-}
-
-// Reset clears all file tracking records. Useful for testing.
-func Reset() {
- recordMutex.Lock()
- defer recordMutex.Unlock()
- records = make(map[string]record)
-}
@@ -0,0 +1,70 @@
+// Package filetracker provides functionality to track file reads in sessions.
+package filetracker
+
+import (
+ "context"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/db"
+)
+
+// Service defines the interface for tracking file reads in sessions.
+type Service interface {
+ // RecordRead records when a file was read.
+ RecordRead(ctx context.Context, sessionID, path string)
+
+ // LastReadTime returns when a file was last read.
+ // Returns zero time if never read.
+ LastReadTime(ctx context.Context, sessionID, path string) time.Time
+}
+
+type service struct {
+ q *db.Queries
+}
+
+// NewService creates a new file tracker service.
+func NewService(q *db.Queries) Service {
+ return &service{q: q}
+}
+
+// RecordRead records when a file was read.
+func (s *service) RecordRead(ctx context.Context, sessionID, path string) {
+ if err := s.q.RecordFileRead(ctx, db.RecordFileReadParams{
+ SessionID: sessionID,
+ Path: relpath(path),
+ }); err != nil {
+ slog.Error("Error recording file read", "error", err, "file", path)
+ }
+}
+
+// LastReadTime returns when a file was last read.
+// Returns zero time if never read.
+func (s *service) LastReadTime(ctx context.Context, sessionID, path string) time.Time {
+ readFile, err := s.q.GetFileRead(ctx, db.GetFileReadParams{
+ SessionID: sessionID,
+ Path: relpath(path),
+ })
+ if err != nil {
+ return time.Time{}
+ }
+
+ return time.Unix(readFile.ReadAt, 0)
+}
+
+func relpath(path string) string {
+ path = filepath.Clean(path)
+ basepath, err := os.Getwd()
+ if err != nil {
+ slog.Warn("Error getting basepath", "error", err)
+ return path
+ }
+ relpath, err := filepath.Rel(basepath, path)
+ if err != nil {
+ slog.Warn("Error getting relpath", "error", err)
+ return path
+ }
+ return relpath
+}
@@ -0,0 +1,116 @@
+package filetracker
+
+import (
+ "context"
+ "testing"
+ "testing/synctest"
+ "time"
+
+ "github.com/charmbracelet/crush/internal/db"
+ "github.com/stretchr/testify/require"
+)
+
+type testEnv struct {
+ ctx context.Context
+ q *db.Queries
+ svc Service
+}
+
+func setupTest(t *testing.T) *testEnv {
+ t.Helper()
+
+ conn, err := db.Connect(t.Context(), t.TempDir())
+ require.NoError(t, err)
+ t.Cleanup(func() { conn.Close() })
+
+ q := db.New(conn)
+ return &testEnv{
+ ctx: t.Context(),
+ q: q,
+ svc: NewService(q),
+ }
+}
+
+func (e *testEnv) createSession(t *testing.T, sessionID string) {
+ t.Helper()
+ _, err := e.q.CreateSession(e.ctx, db.CreateSessionParams{
+ ID: sessionID,
+ Title: "Test Session",
+ })
+ require.NoError(t, err)
+}
+
+func TestService_RecordRead(t *testing.T) {
+ env := setupTest(t)
+
+ sessionID := "test-session-1"
+ path := "/path/to/file.go"
+ env.createSession(t, sessionID)
+
+ env.svc.RecordRead(env.ctx, sessionID, path)
+
+ lastRead := env.svc.LastReadTime(env.ctx, sessionID, path)
+ require.False(t, lastRead.IsZero(), "expected non-zero time after recording read")
+ require.WithinDuration(t, time.Now(), lastRead, 2*time.Second)
+}
+
+func TestService_LastReadTime_NotFound(t *testing.T) {
+ env := setupTest(t)
+
+ lastRead := env.svc.LastReadTime(env.ctx, "nonexistent-session", "/nonexistent/path")
+ require.True(t, lastRead.IsZero(), "expected zero time for unread file")
+}
+
+func TestService_RecordRead_UpdatesTimestamp(t *testing.T) {
+ env := setupTest(t)
+
+ sessionID := "test-session-2"
+ path := "/path/to/file.go"
+ env.createSession(t, sessionID)
+
+ env.svc.RecordRead(env.ctx, sessionID, path)
+ firstRead := env.svc.LastReadTime(env.ctx, sessionID, path)
+ require.False(t, firstRead.IsZero())
+
+ synctest.Test(t, func(t *testing.T) {
+ time.Sleep(100 * time.Millisecond)
+ synctest.Wait()
+ env.svc.RecordRead(env.ctx, sessionID, path)
+ secondRead := env.svc.LastReadTime(env.ctx, sessionID, path)
+
+ require.False(t, secondRead.Before(firstRead), "second read time should not be before first")
+ })
+}
+
+func TestService_RecordRead_DifferentSessions(t *testing.T) {
+ env := setupTest(t)
+
+ path := "/shared/file.go"
+ session1, session2 := "session-1", "session-2"
+ env.createSession(t, session1)
+ env.createSession(t, session2)
+
+ env.svc.RecordRead(env.ctx, session1, path)
+
+ lastRead1 := env.svc.LastReadTime(env.ctx, session1, path)
+ require.False(t, lastRead1.IsZero())
+
+ lastRead2 := env.svc.LastReadTime(env.ctx, session2, path)
+ require.True(t, lastRead2.IsZero(), "session 2 should not see session 1's read")
+}
+
+func TestService_RecordRead_DifferentPaths(t *testing.T) {
+ env := setupTest(t)
+
+ sessionID := "test-session-3"
+ path1, path2 := "/path/to/file1.go", "/path/to/file2.go"
+ env.createSession(t, sessionID)
+
+ env.svc.RecordRead(env.ctx, sessionID, path1)
+
+ lastRead1 := env.svc.LastReadTime(env.ctx, sessionID, path1)
+ require.False(t, lastRead1.IsZero())
+
+ lastRead2 := env.svc.LastReadTime(env.ctx, sessionID, path2)
+ require.True(t, lastRead2.IsZero(), "path2 should not be recorded")
+}
@@ -1,6 +1,7 @@
package editor
import (
+ "context"
"fmt"
"math/rand"
"net/http"
@@ -17,7 +18,6 @@ import (
tea "charm.land/bubbletea/v2"
"charm.land/lipgloss/v2"
"github.com/charmbracelet/crush/internal/app"
- "github.com/charmbracelet/crush/internal/filetracker"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/session"
@@ -66,6 +66,7 @@ type editorCmp struct {
x, y int
app *app.App
session session.Session
+ sessionFileReads []string
textarea textarea.Model
attachments []message.Attachment
deleteMode bool
@@ -181,6 +182,9 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
var cmd tea.Cmd
var cmds []tea.Cmd
switch msg := msg.(type) {
+ case chat.SessionClearedMsg:
+ m.session = session.Session{}
+ m.sessionFileReads = nil
case tea.WindowSizeMsg:
return m, m.repositionCompletions
case filepicker.FilePickedMsg:
@@ -212,19 +216,27 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) {
m.completionsStartIndex = 0
}
absPath, _ := filepath.Abs(item.Path)
+
+ ctx := context.Background()
+
// Skip attachment if file was already read and hasn't been modified.
- lastRead := filetracker.LastReadTime(absPath)
- if !lastRead.IsZero() {
- if info, err := os.Stat(item.Path); err == nil && !info.ModTime().After(lastRead) {
- return m, nil
+ if m.session.ID != "" {
+ lastRead := m.app.FileTracker.LastReadTime(ctx, m.session.ID, absPath)
+ if !lastRead.IsZero() {
+ if info, err := os.Stat(item.Path); err == nil && !info.ModTime().After(lastRead) {
+ return m, nil
+ }
}
+ } else if slices.Contains(m.sessionFileReads, absPath) {
+ return m, nil
}
+
+ m.sessionFileReads = append(m.sessionFileReads, absPath)
content, err := os.ReadFile(item.Path)
if err != nil {
// if it fails, let the LLM handle it later.
return m, nil
}
- filetracker.RecordRead(absPath)
m.attachments = append(m.attachments, message.Attachment{
FilePath: item.Path,
FileName: filepath.Base(item.Path),
@@ -662,6 +674,9 @@ func (c *editorCmp) Bindings() []key.Binding {
// we need to move some functionality to the page level
func (c *editorCmp) SetSession(session session.Session) tea.Cmd {
c.session = session
+ for _, path := range c.sessionFileReads {
+ c.app.FileTracker.RecordRead(context.Background(), session.ID, path)
+ }
return nil
}
@@ -327,6 +327,9 @@ func (p *chatPage) Update(msg tea.Msg) (util.Model, tea.Cmd) {
u, cmd = p.chat.Update(msg)
p.chat = u.(chat.MessageListCmp)
cmds = append(cmds, cmd)
+ u, cmd = p.editor.Update(msg)
+ p.editor = u.(editor.Editor)
+ cmds = append(cmds, cmd)
return p, tea.Batch(cmds...)
case filepicker.FilePickedMsg,
completions.CompletionsClosedMsg,
@@ -28,7 +28,6 @@ import (
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/commands"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/filetracker"
"github.com/charmbracelet/crush/internal/fsext"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/home"
@@ -118,6 +117,9 @@ type UI struct {
session *session.Session
sessionFiles []SessionFile
+ // keeps track of read files while we don't have a session id
+ sessionFileReads []string
+
lastUserMessageTime int64
// The width and height of the terminal in cells.
@@ -2414,21 +2416,27 @@ func (m *UI) insertFileCompletion(path string) tea.Cmd {
return func() tea.Msg {
absPath, _ := filepath.Abs(path)
- // Skip attachment if file was already read and hasn't been modified.
- lastRead := filetracker.LastReadTime(absPath)
- if !lastRead.IsZero() {
- if info, err := os.Stat(path); err == nil && !info.ModTime().After(lastRead) {
- return nil
+
+ if m.hasSession() {
+ // Skip attachment if file was already read and hasn't been modified.
+ lastRead := m.com.App.FileTracker.LastReadTime(context.Background(), m.session.ID, absPath)
+ if !lastRead.IsZero() {
+ if info, err := os.Stat(path); err == nil && !info.ModTime().After(lastRead) {
+ return nil
+ }
}
+ } else if slices.Contains(m.sessionFileReads, absPath) {
+ return nil
}
+ m.sessionFileReads = append(m.sessionFileReads, absPath)
+
// Add file as attachment.
content, err := os.ReadFile(path)
if err != nil {
// If it fails, let the LLM handle it later.
return nil
}
- filetracker.RecordRead(absPath)
return message.Attachment{
FilePath: path,
@@ -2555,6 +2563,10 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea.
m.setState(uiChat, m.focus)
}
+ for _, path := range m.sessionFileReads {
+ m.com.App.FileTracker.RecordRead(context.Background(), m.session.ID, path)
+ }
+
// Capture session ID to avoid race with main goroutine updating m.session.
sessionID := m.session.ID
cmds = append(cmds, func() tea.Msg {
@@ -2801,6 +2813,7 @@ func (m *UI) newSession() tea.Cmd {
m.session = nil
m.sessionFiles = nil
+ m.sessionFileReads = nil
m.setState(uiLanding, uiFocusEditor)
m.textarea.Focus()
m.chat.Blur()