diff --git a/internal/agent/agentic_fetch_tool.go b/internal/agent/agentic_fetch_tool.go index 89d3535720f8452111f12f4df4eb691e39253bed..08da0e870187f537c9c88ac6a2b6ada97ff6fc88 100644 --- a/internal/agent/agentic_fetch_tool.go +++ b/internal/agent/agentic_fetch_tool.go @@ -168,7 +168,7 @@ func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) ( tools.NewGlobTool(tmpDir), tools.NewGrepTool(tmpDir), tools.NewSourcegraphTool(client), - tools.NewViewTool(c.lspClients, c.permissions, tmpDir), + tools.NewViewTool(c.lspClients, c.permissions, c.filetracker, tmpDir), } agent := NewSessionAgent(SessionAgentOptions{ diff --git a/internal/agent/common_test.go b/internal/agent/common_test.go index 3f4e8daddbd4de34e788bce59a9573c00d940252..2bb5e5650bcb3280ddb95bdcea7d588a2eea7643 100644 --- a/internal/agent/common_test.go +++ b/internal/agent/common_test.go @@ -20,6 +20,7 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" @@ -37,6 +38,7 @@ type fakeEnv struct { messages message.Service permissions permission.Service history history.Service + filetracker *filetracker.Service lspClients *csync.Map[string, *lsp.Client] } @@ -117,6 +119,7 @@ func testEnv(t *testing.T) fakeEnv { permissions := permission.NewPermissionService(workingDir, true, []string{}) history := history.NewService(q, conn) + filetrackerService := filetracker.NewService(q) lspClients := csync.NewMap[string, *lsp.Client]() t.Cleanup(func() { @@ -130,6 +133,7 @@ func testEnv(t *testing.T) fakeEnv { messages, permissions, history, + &filetrackerService, lspClients, } } @@ -200,15 +204,15 @@ func coderAgent(r *vcr.Recorder, env fakeEnv, large, small fantasy.LanguageModel allTools := []fantasy.AgentTool{ tools.NewBashTool(env.permissions, env.workingDir, cfg.Options.Attribution, modelName), tools.NewDownloadTool(env.permissions, env.workingDir, r.GetDefaultClient()), - tools.NewEditTool(env.lspClients, env.permissions, env.history, env.workingDir), - tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, env.workingDir), + tools.NewEditTool(env.lspClients, env.permissions, env.history, *env.filetracker, env.workingDir), + tools.NewMultiEditTool(env.lspClients, env.permissions, env.history, *env.filetracker, env.workingDir), tools.NewFetchTool(env.permissions, env.workingDir, r.GetDefaultClient()), tools.NewGlobTool(env.workingDir), tools.NewGrepTool(env.workingDir), tools.NewLsTool(env.permissions, env.workingDir, cfg.Tools.Ls), tools.NewSourcegraphTool(r.GetDefaultClient()), - tools.NewViewTool(env.lspClients, env.permissions, env.workingDir), - tools.NewWriteTool(env.lspClients, env.permissions, env.history, env.workingDir), + tools.NewViewTool(env.lspClients, env.permissions, *env.filetracker, env.workingDir), + tools.NewWriteTool(env.lspClients, env.permissions, env.history, *env.filetracker, env.workingDir), } return testSessionAgent(env, large, small, systemPrompt, allTools...), nil diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 8c2a785b2f8ffeb77bbf52bb9653e8a98369303b..fd65072fd4eb297b8eddcb38aafe50d595601f82 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -22,6 +22,7 @@ import ( "github.com/charmbracelet/crush/internal/agent/tools" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/log" "github.com/charmbracelet/crush/internal/lsp" @@ -64,6 +65,7 @@ type coordinator struct { messages message.Service permissions permission.Service history history.Service + filetracker filetracker.Service lspClients *csync.Map[string, *lsp.Client] currentAgent SessionAgent @@ -79,6 +81,7 @@ func NewCoordinator( messages message.Service, permissions permission.Service, history history.Service, + filetracker filetracker.Service, lspClients *csync.Map[string, *lsp.Client], ) (Coordinator, error) { c := &coordinator{ @@ -87,6 +90,7 @@ func NewCoordinator( messages: messages, permissions: permissions, history: history, + filetracker: filetracker, lspClients: lspClients, agents: make(map[string]SessionAgent), } @@ -393,16 +397,16 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan tools.NewJobOutputTool(), tools.NewJobKillTool(), tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil), - tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()), - tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()), + tools.NewEditTool(c.lspClients, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), + tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil), tools.NewGlobTool(c.cfg.WorkingDir()), tools.NewGrepTool(c.cfg.WorkingDir()), tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls), tools.NewSourcegraphTool(nil), tools.NewTodosTool(c.sessions), - tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...), - tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()), + tools.NewViewTool(c.lspClients, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...), + tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()), ) if len(c.cfg.LSP) > 0 { diff --git a/internal/agent/tools/edit.go b/internal/agent/tools/edit.go index 2c9b15abfe148fb881ee90f75f207c1134776281..74b84c784796a97db2f379cf61fb3eb8b18934d4 100644 --- a/internal/agent/tools/edit.go +++ b/internal/agent/tools/edit.go @@ -56,10 +56,17 @@ type editContext struct { ctx context.Context permissions permission.Service files history.Service + filetracker filetracker.Service workingDir string } -func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool { +func NewEditTool( + lspClients *csync.Map[string, *lsp.Client], + permissions permission.Service, + files history.Service, + filetracker filetracker.Service, + workingDir string, +) fantasy.AgentTool { return fantasy.NewAgentTool( EditToolName, string(editDescription), @@ -73,7 +80,7 @@ func NewEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss var response fantasy.ToolResponse var err error - editCtx := editContext{ctx, permissions, files, workingDir} + editCtx := editContext{ctx, permissions, files, filetracker, workingDir} if params.OldString == "" { response, err = createNewFile(editCtx, params.FilePath, params.NewString, call) @@ -168,8 +175,7 @@ func createNewFile(edit editContext, filePath, content string, call fantasy.Tool slog.Error("Error creating file history version", "error", err) } - filetracker.RecordWrite(filePath) - filetracker.RecordRead(filePath) + edit.filetracker.RecordRead(edit.ctx, sessionID, filePath) return fantasy.WithResponseMetadata( fantasy.NewTextResponse("File created: "+filePath), @@ -195,12 +201,17 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool return fantasy.NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil } - if filetracker.LastReadTime(filePath).IsZero() { + sessionID := GetSessionFromContext(edit.ctx) + if sessionID == "" { + return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for deleting content") + } + + lastRead := edit.filetracker.LastReadTime(edit.ctx, sessionID, filePath) + if lastRead.IsZero() { return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil } - modTime := fileInfo.ModTime() - lastRead := filetracker.LastReadTime(filePath) + modTime := fileInfo.ModTime().Truncate(time.Second) if modTime.After(lastRead) { return fantasy.NewTextErrorResponse( fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", @@ -236,12 +247,6 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool newContent = oldContent[:index] + oldContent[index+len(oldString):] } - sessionID := GetSessionFromContext(edit.ctx) - - if sessionID == "" { - return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for deleting content") - } - _, additions, removals := diff.GenerateDiff( oldContent, newContent, @@ -301,8 +306,7 @@ func deleteContent(edit editContext, filePath, oldString string, replaceAll bool slog.Error("Error creating file history version", "error", err) } - filetracker.RecordWrite(filePath) - filetracker.RecordRead(filePath) + edit.filetracker.RecordRead(edit.ctx, sessionID, filePath) return fantasy.WithResponseMetadata( fantasy.NewTextResponse("Content deleted from file: "+filePath), @@ -328,12 +332,17 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep return fantasy.NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", filePath)), nil } - if filetracker.LastReadTime(filePath).IsZero() { + sessionID := GetSessionFromContext(edit.ctx) + if sessionID == "" { + return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for edit a file") + } + + lastRead := edit.filetracker.LastReadTime(edit.ctx, sessionID, filePath) + if lastRead.IsZero() { return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil } - modTime := fileInfo.ModTime() - lastRead := filetracker.LastReadTime(filePath) + modTime := fileInfo.ModTime().Truncate(time.Second) if modTime.After(lastRead) { return fantasy.NewTextErrorResponse( fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", @@ -369,11 +378,6 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep if oldContent == newContent { return fantasy.NewTextErrorResponse("new content is the same as old content. No changes made."), nil } - sessionID := GetSessionFromContext(edit.ctx) - - if sessionID == "" { - return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file") - } _, additions, removals := diff.GenerateDiff( oldContent, newContent, @@ -433,8 +437,7 @@ func replaceContent(edit editContext, filePath, oldString, newString string, rep slog.Error("Error creating file history version", "error", err) } - filetracker.RecordWrite(filePath) - filetracker.RecordRead(filePath) + edit.filetracker.RecordRead(edit.ctx, sessionID, filePath) return fantasy.WithResponseMetadata( fantasy.NewTextResponse("Content replaced in file: "+filePath), diff --git a/internal/agent/tools/multiedit.go b/internal/agent/tools/multiedit.go index 0640228d23230e6a49d8e1405f371c099031fbf7..48736ebf311230a28b51702e0ddd3ff8df19b284 100644 --- a/internal/agent/tools/multiedit.go +++ b/internal/agent/tools/multiedit.go @@ -58,7 +58,13 @@ const MultiEditToolName = "multiedit" //go:embed multiedit.md var multieditDescription []byte -func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool { +func NewMultiEditTool( + lspClients *csync.Map[string, *lsp.Client], + permissions permission.Service, + files history.Service, + filetracker filetracker.Service, + workingDir string, +) fantasy.AgentTool { return fantasy.NewAgentTool( MultiEditToolName, string(multieditDescription), @@ -81,7 +87,7 @@ func NewMultiEditTool(lspClients *csync.Map[string, *lsp.Client], permissions pe var response fantasy.ToolResponse var err error - editCtx := editContext{ctx, permissions, files, workingDir} + editCtx := editContext{ctx, permissions, files, filetracker, workingDir} // Handle file creation case (first edit has empty old_string) if len(params.Edits) > 0 && params.Edits[0].OldString == "" { response, err = processMultiEditWithCreation(editCtx, params, call) @@ -210,8 +216,7 @@ func processMultiEditWithCreation(edit editContext, params MultiEditParams, call slog.Error("Error creating file history version", "error", err) } - filetracker.RecordWrite(params.FilePath) - filetracker.RecordRead(params.FilePath) + edit.filetracker.RecordRead(edit.ctx, sessionID, params.FilePath) var message string if len(failedEdits) > 0 { @@ -247,14 +252,19 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call return fantasy.NewTextErrorResponse(fmt.Sprintf("path is a directory, not a file: %s", params.FilePath)), nil } + sessionID := GetSessionFromContext(edit.ctx) + if sessionID == "" { + return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for editing file") + } + // Check if file was read before editing - if filetracker.LastReadTime(params.FilePath).IsZero() { + lastRead := edit.filetracker.LastReadTime(edit.ctx, sessionID, params.FilePath) + if lastRead.IsZero() { return fantasy.NewTextErrorResponse("you must read the file before editing it. Use the View tool first"), nil } - // Check if file was modified since last read - modTime := fileInfo.ModTime() - lastRead := filetracker.LastReadTime(params.FilePath) + // Check if file was modified since last read. + modTime := fileInfo.ModTime().Truncate(time.Second) if modTime.After(lastRead) { return fantasy.NewTextErrorResponse( fmt.Sprintf("file %s has been modified since it was last read (mod time: %s, last read: %s)", @@ -301,12 +311,6 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call return fantasy.NewTextErrorResponse("no changes made - all edits resulted in identical content"), nil } - // Get session and message IDs - sessionID := GetSessionFromContext(edit.ctx) - if sessionID == "" { - return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for editing file") - } - // Generate diff and check permissions _, additions, removals := diff.GenerateDiff(oldContent, currentContent, strings.TrimPrefix(params.FilePath, edit.workingDir)) @@ -369,8 +373,7 @@ func processMultiEditExistingFile(edit editContext, params MultiEditParams, call slog.Error("Error creating file history version", "error", err) } - filetracker.RecordWrite(params.FilePath) - filetracker.RecordRead(params.FilePath) + edit.filetracker.RecordRead(edit.ctx, sessionID, params.FilePath) var message string if len(failedEdits) > 0 { diff --git a/internal/agent/tools/multiedit_test.go b/internal/agent/tools/multiedit_test.go index b6d575435e63dcd62a4dc9a7efb76cf13c14ad05..1ca2a6f7689e345ac944889f1f92284de0652f90 100644 --- a/internal/agent/tools/multiedit_test.go +++ b/internal/agent/tools/multiedit_test.go @@ -6,10 +6,7 @@ import ( "path/filepath" "testing" - "github.com/charmbracelet/crush/internal/csync" - "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/history" - "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/permission" "github.com/charmbracelet/crush/internal/pubsub" "github.com/stretchr/testify/require" @@ -111,17 +108,6 @@ func TestMultiEditSequentialApplication(t *testing.T) { err := os.WriteFile(testFile, []byte(content), 0o644) require.NoError(t, err) - // Mock components. - lspClients := csync.NewMap[string, *lsp.Client]() - permissions := &mockPermissionService{Broker: pubsub.NewBroker[permission.PermissionRequest]()} - files := &mockHistoryService{Broker: pubsub.NewBroker[history.File]()} - - // Create multiedit tool. - _ = NewMultiEditTool(lspClients, permissions, files, tmpDir) - - // Simulate reading the file first. - filetracker.RecordRead(testFile) - // Manually test the sequential application logic. currentContent := content diff --git a/internal/agent/tools/view.go b/internal/agent/tools/view.go index 35865cf43f7c587d60764b3ed177374940bbe2dc..b26267fcef3b296babc3c9dbcee64336ef162b75 100644 --- a/internal/agent/tools/view.go +++ b/internal/agent/tools/view.go @@ -47,7 +47,13 @@ const ( MaxLineLength = 2000 ) -func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, workingDir string, skillsPaths ...string) fantasy.AgentTool { +func NewViewTool( + lspClients *csync.Map[string, *lsp.Client], + permissions permission.Service, + filetracker filetracker.Service, + workingDir string, + skillsPaths ...string, +) fantasy.AgentTool { return fantasy.NewAgentTool( ViewToolName, string(viewDescription), @@ -74,13 +80,13 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss isOutsideWorkDir := err != nil || strings.HasPrefix(relPath, "..") isSkillFile := isInSkillsPath(absFilePath, skillsPaths) + sessionID := GetSessionFromContext(ctx) + if sessionID == "" { + return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory") + } + // Request permission for files outside working directory, unless it's a skill file. if isOutsideWorkDir && !isSkillFile { - sessionID := GetSessionFromContext(ctx) - if sessionID == "" { - return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for accessing files outside working directory") - } - granted, err := permissions.Request(ctx, permission.CreatePermissionRequest{ SessionID: sessionID, @@ -190,7 +196,7 @@ func NewViewTool(lspClients *csync.Map[string, *lsp.Client], permissions permiss } output += "\n\n" output += getDiagnostics(filePath, lspClients) - filetracker.RecordRead(filePath) + filetracker.RecordRead(ctx, sessionID, filePath) return fantasy.WithResponseMetadata( fantasy.NewTextResponse(output), ViewResponseMetadata{ diff --git a/internal/agent/tools/write.go b/internal/agent/tools/write.go index 8becaea3c08157897dcece7b3d5d4de5cb2ee929..c2f5c7d1c83efd0731e8623c1e9cbb98b9bfdd2f 100644 --- a/internal/agent/tools/write.go +++ b/internal/agent/tools/write.go @@ -44,7 +44,13 @@ type WriteResponseMetadata struct { const WriteToolName = "write" -func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permission.Service, files history.Service, workingDir string) fantasy.AgentTool { +func NewWriteTool( + lspClients *csync.Map[string, *lsp.Client], + permissions permission.Service, + files history.Service, + filetracker filetracker.Service, + workingDir string, +) fantasy.AgentTool { return fantasy.NewAgentTool( WriteToolName, string(writeDescription), @@ -57,6 +63,11 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis return fantasy.NewTextErrorResponse("content is required"), nil } + sessionID := GetSessionFromContext(ctx) + if sessionID == "" { + return fantasy.ToolResponse{}, fmt.Errorf("session_id is required") + } + filePath := filepathext.SmartJoin(workingDir, params.FilePath) fileInfo, err := os.Stat(filePath) @@ -65,8 +76,8 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis return fantasy.NewTextErrorResponse(fmt.Sprintf("Path is a directory, not a file: %s", filePath)), nil } - modTime := fileInfo.ModTime() - lastRead := filetracker.LastReadTime(filePath) + modTime := fileInfo.ModTime().Truncate(time.Second) + lastRead := filetracker.LastReadTime(ctx, sessionID, filePath) if modTime.After(lastRead) { return fantasy.NewTextErrorResponse(fmt.Sprintf("File %s has been modified since it was last read.\nLast modification: %s\nLast read: %s\n\nPlease read the file again before modifying it.", filePath, modTime.Format(time.RFC3339), lastRead.Format(time.RFC3339))), nil @@ -93,11 +104,6 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis } } - sessionID := GetSessionFromContext(ctx) - if sessionID == "" { - return fantasy.ToolResponse{}, fmt.Errorf("session_id is required") - } - diff, additions, removals := diff.GenerateDiff( oldContent, params.Content, @@ -153,8 +159,7 @@ func NewWriteTool(lspClients *csync.Map[string, *lsp.Client], permissions permis slog.Error("Error creating file history version", "error", err) } - filetracker.RecordWrite(filePath) - filetracker.RecordRead(filePath) + filetracker.RecordRead(ctx, sessionID, filePath) notifyLSPs(ctx, lspClients, params.FilePath) diff --git a/internal/app/app.go b/internal/app/app.go index ef6e636e44eeea9407557ca48f8ba9bd8eba72b2..647d90c9cfe29402b00ef5743f3a84f5e1b681ab 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -23,6 +23,7 @@ import ( "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/csync" "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/format" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/log" @@ -53,6 +54,7 @@ type App struct { Messages message.Service History history.Service Permissions permission.Service + FileTracker filetracker.Service AgentCoordinator agent.Coordinator @@ -87,6 +89,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { Messages: messages, History: files, Permissions: permission.NewPermissionService(cfg.WorkingDir(), skipPermissionsRequests, allowedTools), + FileTracker: filetracker.NewService(q), LSPClients: csync.NewMap[string, *lsp.Client](), globalCtx: ctx, @@ -460,6 +463,7 @@ func (app *App) InitCoderAgent(ctx context.Context) error { app.Messages, app.Permissions, app.History, + app.FileTracker, app.LSPClients, ) if err != nil { diff --git a/internal/db/db.go b/internal/db/db.go index a4e430c720f33f4cd3c0b9710633595ef5c5fa1f..739c2087e1c1e125875d5006c86f85de37fed3be 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -57,6 +57,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.getFileByPathAndSessionStmt, err = db.PrepareContext(ctx, getFileByPathAndSession); err != nil { return nil, fmt.Errorf("error preparing query GetFileByPathAndSession: %w", err) } + if q.getFileReadStmt, err = db.PrepareContext(ctx, getFileRead); err != nil { + return nil, fmt.Errorf("error preparing query GetFileRead: %w", err) + } if q.getHourDayHeatmapStmt, err = db.PrepareContext(ctx, getHourDayHeatmap); err != nil { return nil, fmt.Errorf("error preparing query GetHourDayHeatmap: %w", err) } @@ -111,6 +114,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.listUserMessagesBySessionStmt, err = db.PrepareContext(ctx, listUserMessagesBySession); err != nil { return nil, fmt.Errorf("error preparing query ListUserMessagesBySession: %w", err) } + if q.recordFileReadStmt, err = db.PrepareContext(ctx, recordFileRead); err != nil { + return nil, fmt.Errorf("error preparing query RecordFileRead: %w", err) + } if q.updateMessageStmt, err = db.PrepareContext(ctx, updateMessage); err != nil { return nil, fmt.Errorf("error preparing query UpdateMessage: %w", err) } @@ -180,6 +186,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing getFileByPathAndSessionStmt: %w", cerr) } } + if q.getFileReadStmt != nil { + if cerr := q.getFileReadStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing getFileReadStmt: %w", cerr) + } + } if q.getHourDayHeatmapStmt != nil { if cerr := q.getHourDayHeatmapStmt.Close(); cerr != nil { err = fmt.Errorf("error closing getHourDayHeatmapStmt: %w", cerr) @@ -270,6 +281,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing listUserMessagesBySessionStmt: %w", cerr) } } + if q.recordFileReadStmt != nil { + if cerr := q.recordFileReadStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing recordFileReadStmt: %w", cerr) + } + } if q.updateMessageStmt != nil { if cerr := q.updateMessageStmt.Close(); cerr != nil { err = fmt.Errorf("error closing updateMessageStmt: %w", cerr) @@ -335,6 +351,7 @@ type Queries struct { getAverageResponseTimeStmt *sql.Stmt getFileStmt *sql.Stmt getFileByPathAndSessionStmt *sql.Stmt + getFileReadStmt *sql.Stmt getHourDayHeatmapStmt *sql.Stmt getMessageStmt *sql.Stmt getRecentActivityStmt *sql.Stmt @@ -353,6 +370,7 @@ type Queries struct { listNewFilesStmt *sql.Stmt listSessionsStmt *sql.Stmt listUserMessagesBySessionStmt *sql.Stmt + recordFileReadStmt *sql.Stmt updateMessageStmt *sql.Stmt updateSessionStmt *sql.Stmt updateSessionTitleAndUsageStmt *sql.Stmt @@ -373,6 +391,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { getAverageResponseTimeStmt: q.getAverageResponseTimeStmt, getFileStmt: q.getFileStmt, getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, + getFileReadStmt: q.getFileReadStmt, getHourDayHeatmapStmt: q.getHourDayHeatmapStmt, getMessageStmt: q.getMessageStmt, getRecentActivityStmt: q.getRecentActivityStmt, @@ -391,6 +410,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries { listNewFilesStmt: q.listNewFilesStmt, listSessionsStmt: q.listSessionsStmt, listUserMessagesBySessionStmt: q.listUserMessagesBySessionStmt, + recordFileReadStmt: q.recordFileReadStmt, updateMessageStmt: q.updateMessageStmt, updateSessionStmt: q.updateSessionStmt, updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt, diff --git a/internal/db/migrations/20260127000000_add_read_files_table.sql b/internal/db/migrations/20260127000000_add_read_files_table.sql new file mode 100644 index 0000000000000000000000000000000000000000..1161f1992885fc66e309024a0d874565ea276229 --- /dev/null +++ b/internal/db/migrations/20260127000000_add_read_files_table.sql @@ -0,0 +1,20 @@ +-- +goose Up +-- +goose StatementBegin +CREATE TABLE IF NOT EXISTS read_files ( + session_id TEXT NOT NULL CHECK (session_id != ''), + path TEXT NOT NULL CHECK (path != ''), + read_at INTEGER NOT NULL, -- Unix timestamp in seconds when file was last read + FOREIGN KEY (session_id) REFERENCES sessions (id) ON DELETE CASCADE, + PRIMARY KEY (path, session_id) +); + +CREATE INDEX IF NOT EXISTS idx_read_files_session_id ON read_files (session_id); +CREATE INDEX IF NOT EXISTS idx_read_files_path ON read_files (path); +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +DROP INDEX IF EXISTS idx_read_files_path; +DROP INDEX IF EXISTS idx_read_files_session_id; +DROP TABLE IF EXISTS read_files; +-- +goose StatementEnd diff --git a/internal/db/models.go b/internal/db/models.go index 317e7c92e09c857ee610832e365af2c4ecc90181..a105074ab9e6320bd92b90121e7694b1f8cd1e5a 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -31,6 +31,12 @@ type Message struct { IsSummaryMessage int64 `json:"is_summary_message"` } +type ReadFile struct { + SessionID string `json:"session_id"` + Path string `json:"path"` + ReadAt int64 `json:"read_at"` // Unix timestamp when file was last read +} + type Session struct { ID string `json:"id"` ParentSessionID sql.NullString `json:"parent_session_id"` diff --git a/internal/db/querier.go b/internal/db/querier.go index 394ba1f71aea47c93956e91fcaf07e02f65098b8..c233fd59f63f8b46d3e6d62e1c162f47d6d34e3f 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -20,6 +20,7 @@ type Querier interface { GetAverageResponseTime(ctx context.Context) (int64, error) GetFile(ctx context.Context, id string) (File, error) GetFileByPathAndSession(ctx context.Context, arg GetFileByPathAndSessionParams) (File, error) + GetFileRead(ctx context.Context, arg GetFileReadParams) (ReadFile, error) GetHourDayHeatmap(ctx context.Context) ([]GetHourDayHeatmapRow, error) GetMessage(ctx context.Context, id string) (Message, error) GetRecentActivity(ctx context.Context) ([]GetRecentActivityRow, error) @@ -38,6 +39,7 @@ type Querier interface { ListNewFiles(ctx context.Context) ([]File, error) ListSessions(ctx context.Context) ([]Session, error) ListUserMessagesBySession(ctx context.Context, sessionID string) ([]Message, error) + RecordFileRead(ctx context.Context, arg RecordFileReadParams) error UpdateMessage(ctx context.Context, arg UpdateMessageParams) error UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error diff --git a/internal/db/read_files.sql.go b/internal/db/read_files.sql.go new file mode 100644 index 0000000000000000000000000000000000000000..b18907c1f27a3c753b6b1a2cf1ca0563c3fd78d5 --- /dev/null +++ b/internal/db/read_files.sql.go @@ -0,0 +1,57 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: read_files.sql + +package db + +import ( + "context" +) + +const getFileRead = `-- name: GetFileRead :one +SELECT session_id, path, read_at FROM read_files +WHERE session_id = ? AND path = ? LIMIT 1 +` + +type GetFileReadParams struct { + SessionID string `json:"session_id"` + Path string `json:"path"` +} + +func (q *Queries) GetFileRead(ctx context.Context, arg GetFileReadParams) (ReadFile, error) { + row := q.queryRow(ctx, q.getFileReadStmt, getFileRead, arg.SessionID, arg.Path) + var i ReadFile + err := row.Scan( + &i.SessionID, + &i.Path, + &i.ReadAt, + ) + return i, err +} + +const recordFileRead = `-- name: RecordFileRead :exec +INSERT INTO read_files ( + session_id, + path, + read_at +) VALUES ( + ?, + ?, + strftime('%s', 'now') +) ON CONFLICT(path, session_id) DO UPDATE SET + read_at = excluded.read_at +` + +type RecordFileReadParams struct { + SessionID string `json:"session_id"` + Path string `json:"path"` +} + +func (q *Queries) RecordFileRead(ctx context.Context, arg RecordFileReadParams) error { + _, err := q.exec(ctx, q.recordFileReadStmt, recordFileRead, + arg.SessionID, + arg.Path, + ) + return err +} diff --git a/internal/db/sql/read_files.sql b/internal/db/sql/read_files.sql new file mode 100644 index 0000000000000000000000000000000000000000..f607312c2ba8660aa2c7030e415ce2ca7320cd6d --- /dev/null +++ b/internal/db/sql/read_files.sql @@ -0,0 +1,15 @@ +-- name: RecordFileRead :exec +INSERT INTO read_files ( + session_id, + path, + read_at +) VALUES ( + ?, + ?, + strftime('%s', 'now') +) ON CONFLICT(path, session_id) DO UPDATE SET + read_at = excluded.read_at; + +-- name: GetFileRead :one +SELECT * FROM read_files +WHERE session_id = ? AND path = ? LIMIT 1; diff --git a/internal/filetracker/filetracker.go b/internal/filetracker/filetracker.go deleted file mode 100644 index 534a19dacdc209f7ef2d9c5b107cb5f88a665ee5..0000000000000000000000000000000000000000 --- a/internal/filetracker/filetracker.go +++ /dev/null @@ -1,70 +0,0 @@ -// Package filetracker tracks file read/write times to prevent editing files -// that haven't been read, and to detect external modifications. -// -// TODO: Consider moving this to persistent storage (e.g., the database) to -// preserve file access history across sessions. -// We would need to make sure to handle the case where we reload a session and the underlying files did change. -package filetracker - -import ( - "sync" - "time" -) - -// record tracks when a file was read/written. -type record struct { - path string - readTime time.Time - writeTime time.Time -} - -var ( - records = make(map[string]record) - recordMutex sync.RWMutex -) - -// RecordRead records when a file was read. -func RecordRead(path string) { - recordMutex.Lock() - defer recordMutex.Unlock() - - rec, exists := records[path] - if !exists { - rec = record{path: path} - } - rec.readTime = time.Now() - records[path] = rec -} - -// LastReadTime returns when a file was last read. Returns zero time if never -// read. -func LastReadTime(path string) time.Time { - recordMutex.RLock() - defer recordMutex.RUnlock() - - rec, exists := records[path] - if !exists { - return time.Time{} - } - return rec.readTime -} - -// RecordWrite records when a file was written. -func RecordWrite(path string) { - recordMutex.Lock() - defer recordMutex.Unlock() - - rec, exists := records[path] - if !exists { - rec = record{path: path} - } - rec.writeTime = time.Now() - records[path] = rec -} - -// Reset clears all file tracking records. Useful for testing. -func Reset() { - recordMutex.Lock() - defer recordMutex.Unlock() - records = make(map[string]record) -} diff --git a/internal/filetracker/service.go b/internal/filetracker/service.go new file mode 100644 index 0000000000000000000000000000000000000000..8f080d124e49dfc32f43796194c09ac22beaa9f1 --- /dev/null +++ b/internal/filetracker/service.go @@ -0,0 +1,70 @@ +// Package filetracker provides functionality to track file reads in sessions. +package filetracker + +import ( + "context" + "log/slog" + "os" + "path/filepath" + "time" + + "github.com/charmbracelet/crush/internal/db" +) + +// Service defines the interface for tracking file reads in sessions. +type Service interface { + // RecordRead records when a file was read. + RecordRead(ctx context.Context, sessionID, path string) + + // LastReadTime returns when a file was last read. + // Returns zero time if never read. + LastReadTime(ctx context.Context, sessionID, path string) time.Time +} + +type service struct { + q *db.Queries +} + +// NewService creates a new file tracker service. +func NewService(q *db.Queries) Service { + return &service{q: q} +} + +// RecordRead records when a file was read. +func (s *service) RecordRead(ctx context.Context, sessionID, path string) { + if err := s.q.RecordFileRead(ctx, db.RecordFileReadParams{ + SessionID: sessionID, + Path: relpath(path), + }); err != nil { + slog.Error("Error recording file read", "error", err, "file", path) + } +} + +// LastReadTime returns when a file was last read. +// Returns zero time if never read. +func (s *service) LastReadTime(ctx context.Context, sessionID, path string) time.Time { + readFile, err := s.q.GetFileRead(ctx, db.GetFileReadParams{ + SessionID: sessionID, + Path: relpath(path), + }) + if err != nil { + return time.Time{} + } + + return time.Unix(readFile.ReadAt, 0) +} + +func relpath(path string) string { + path = filepath.Clean(path) + basepath, err := os.Getwd() + if err != nil { + slog.Warn("Error getting basepath", "error", err) + return path + } + relpath, err := filepath.Rel(basepath, path) + if err != nil { + slog.Warn("Error getting relpath", "error", err) + return path + } + return relpath +} diff --git a/internal/filetracker/service_test.go b/internal/filetracker/service_test.go new file mode 100644 index 0000000000000000000000000000000000000000..c7fb15090dd31e9591c5c3b9c2a256c839aea3f6 --- /dev/null +++ b/internal/filetracker/service_test.go @@ -0,0 +1,116 @@ +package filetracker + +import ( + "context" + "testing" + "testing/synctest" + "time" + + "github.com/charmbracelet/crush/internal/db" + "github.com/stretchr/testify/require" +) + +type testEnv struct { + ctx context.Context + q *db.Queries + svc Service +} + +func setupTest(t *testing.T) *testEnv { + t.Helper() + + conn, err := db.Connect(t.Context(), t.TempDir()) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + + q := db.New(conn) + return &testEnv{ + ctx: t.Context(), + q: q, + svc: NewService(q), + } +} + +func (e *testEnv) createSession(t *testing.T, sessionID string) { + t.Helper() + _, err := e.q.CreateSession(e.ctx, db.CreateSessionParams{ + ID: sessionID, + Title: "Test Session", + }) + require.NoError(t, err) +} + +func TestService_RecordRead(t *testing.T) { + env := setupTest(t) + + sessionID := "test-session-1" + path := "/path/to/file.go" + env.createSession(t, sessionID) + + env.svc.RecordRead(env.ctx, sessionID, path) + + lastRead := env.svc.LastReadTime(env.ctx, sessionID, path) + require.False(t, lastRead.IsZero(), "expected non-zero time after recording read") + require.WithinDuration(t, time.Now(), lastRead, 2*time.Second) +} + +func TestService_LastReadTime_NotFound(t *testing.T) { + env := setupTest(t) + + lastRead := env.svc.LastReadTime(env.ctx, "nonexistent-session", "/nonexistent/path") + require.True(t, lastRead.IsZero(), "expected zero time for unread file") +} + +func TestService_RecordRead_UpdatesTimestamp(t *testing.T) { + env := setupTest(t) + + sessionID := "test-session-2" + path := "/path/to/file.go" + env.createSession(t, sessionID) + + env.svc.RecordRead(env.ctx, sessionID, path) + firstRead := env.svc.LastReadTime(env.ctx, sessionID, path) + require.False(t, firstRead.IsZero()) + + synctest.Test(t, func(t *testing.T) { + time.Sleep(100 * time.Millisecond) + synctest.Wait() + env.svc.RecordRead(env.ctx, sessionID, path) + secondRead := env.svc.LastReadTime(env.ctx, sessionID, path) + + require.False(t, secondRead.Before(firstRead), "second read time should not be before first") + }) +} + +func TestService_RecordRead_DifferentSessions(t *testing.T) { + env := setupTest(t) + + path := "/shared/file.go" + session1, session2 := "session-1", "session-2" + env.createSession(t, session1) + env.createSession(t, session2) + + env.svc.RecordRead(env.ctx, session1, path) + + lastRead1 := env.svc.LastReadTime(env.ctx, session1, path) + require.False(t, lastRead1.IsZero()) + + lastRead2 := env.svc.LastReadTime(env.ctx, session2, path) + require.True(t, lastRead2.IsZero(), "session 2 should not see session 1's read") +} + +func TestService_RecordRead_DifferentPaths(t *testing.T) { + env := setupTest(t) + + sessionID := "test-session-3" + path1, path2 := "/path/to/file1.go", "/path/to/file2.go" + env.createSession(t, sessionID) + + env.svc.RecordRead(env.ctx, sessionID, path1) + + lastRead1 := env.svc.LastReadTime(env.ctx, sessionID, path1) + require.False(t, lastRead1.IsZero()) + + lastRead2 := env.svc.LastReadTime(env.ctx, sessionID, path2) + require.True(t, lastRead2.IsZero(), "path2 should not be recorded") +} diff --git a/internal/tui/components/chat/editor/editor.go b/internal/tui/components/chat/editor/editor.go index ba832b415133305fccbefa37da6b749405feb2c6..575c23114a9115209db7a2a02e642fe5f2246541 100644 --- a/internal/tui/components/chat/editor/editor.go +++ b/internal/tui/components/chat/editor/editor.go @@ -1,6 +1,7 @@ package editor import ( + "context" "fmt" "math/rand" "net/http" @@ -17,7 +18,6 @@ import ( tea "charm.land/bubbletea/v2" "charm.land/lipgloss/v2" "github.com/charmbracelet/crush/internal/app" - "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/session" @@ -66,6 +66,7 @@ type editorCmp struct { x, y int app *app.App session session.Session + sessionFileReads []string textarea textarea.Model attachments []message.Attachment deleteMode bool @@ -181,6 +182,9 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { var cmd tea.Cmd var cmds []tea.Cmd switch msg := msg.(type) { + case chat.SessionClearedMsg: + m.session = session.Session{} + m.sessionFileReads = nil case tea.WindowSizeMsg: return m, m.repositionCompletions case filepicker.FilePickedMsg: @@ -212,19 +216,27 @@ func (m *editorCmp) Update(msg tea.Msg) (util.Model, tea.Cmd) { m.completionsStartIndex = 0 } absPath, _ := filepath.Abs(item.Path) + + ctx := context.Background() + // Skip attachment if file was already read and hasn't been modified. - lastRead := filetracker.LastReadTime(absPath) - if !lastRead.IsZero() { - if info, err := os.Stat(item.Path); err == nil && !info.ModTime().After(lastRead) { - return m, nil + if m.session.ID != "" { + lastRead := m.app.FileTracker.LastReadTime(ctx, m.session.ID, absPath) + if !lastRead.IsZero() { + if info, err := os.Stat(item.Path); err == nil && !info.ModTime().After(lastRead) { + return m, nil + } } + } else if slices.Contains(m.sessionFileReads, absPath) { + return m, nil } + + m.sessionFileReads = append(m.sessionFileReads, absPath) content, err := os.ReadFile(item.Path) if err != nil { // if it fails, let the LLM handle it later. return m, nil } - filetracker.RecordRead(absPath) m.attachments = append(m.attachments, message.Attachment{ FilePath: item.Path, FileName: filepath.Base(item.Path), @@ -662,6 +674,9 @@ func (c *editorCmp) Bindings() []key.Binding { // we need to move some functionality to the page level func (c *editorCmp) SetSession(session session.Session) tea.Cmd { c.session = session + for _, path := range c.sessionFileReads { + c.app.FileTracker.RecordRead(context.Background(), session.ID, path) + } return nil } diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index 9a4b69f5507fbb62b7ee93df6326f94cf79d22ad..bb2eb755bf80995dd41d9ac564174de5b90262bb 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -327,6 +327,9 @@ func (p *chatPage) Update(msg tea.Msg) (util.Model, tea.Cmd) { u, cmd = p.chat.Update(msg) p.chat = u.(chat.MessageListCmp) cmds = append(cmds, cmd) + u, cmd = p.editor.Update(msg) + p.editor = u.(editor.Editor) + cmds = append(cmds, cmd) return p, tea.Batch(cmds...) case filepicker.FilePickedMsg, completions.CompletionsClosedMsg, diff --git a/internal/ui/model/ui.go b/internal/ui/model/ui.go index 1f2d7f86ef1953bf97e98109cbbe5d791c94122f..e100b6605fceceded84da8cd6cfb16507ddf64a4 100644 --- a/internal/ui/model/ui.go +++ b/internal/ui/model/ui.go @@ -28,7 +28,6 @@ import ( "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/commands" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/filetracker" "github.com/charmbracelet/crush/internal/fsext" "github.com/charmbracelet/crush/internal/history" "github.com/charmbracelet/crush/internal/home" @@ -118,6 +117,9 @@ type UI struct { session *session.Session sessionFiles []SessionFile + // keeps track of read files while we don't have a session id + sessionFileReads []string + lastUserMessageTime int64 // The width and height of the terminal in cells. @@ -2414,21 +2416,27 @@ func (m *UI) insertFileCompletion(path string) tea.Cmd { return func() tea.Msg { absPath, _ := filepath.Abs(path) - // Skip attachment if file was already read and hasn't been modified. - lastRead := filetracker.LastReadTime(absPath) - if !lastRead.IsZero() { - if info, err := os.Stat(path); err == nil && !info.ModTime().After(lastRead) { - return nil + + if m.hasSession() { + // Skip attachment if file was already read and hasn't been modified. + lastRead := m.com.App.FileTracker.LastReadTime(context.Background(), m.session.ID, absPath) + if !lastRead.IsZero() { + if info, err := os.Stat(path); err == nil && !info.ModTime().After(lastRead) { + return nil + } } + } else if slices.Contains(m.sessionFileReads, absPath) { + return nil } + m.sessionFileReads = append(m.sessionFileReads, absPath) + // Add file as attachment. content, err := os.ReadFile(path) if err != nil { // If it fails, let the LLM handle it later. return nil } - filetracker.RecordRead(absPath) return message.Attachment{ FilePath: path, @@ -2555,6 +2563,10 @@ func (m *UI) sendMessage(content string, attachments ...message.Attachment) tea. m.setState(uiChat, m.focus) } + for _, path := range m.sessionFileReads { + m.com.App.FileTracker.RecordRead(context.Background(), m.session.ID, path) + } + // Capture session ID to avoid race with main goroutine updating m.session. sessionID := m.session.ID cmds = append(cmds, func() tea.Msg { @@ -2801,6 +2813,7 @@ func (m *UI) newSession() tea.Cmd { m.session = nil m.sessionFiles = nil + m.sessionFileReads = nil m.setState(uiLanding, uiFocusEditor) m.textarea.Focus() m.chat.Blur()