1// Package filetracker provides functionality to track file reads in sessions.
2package filetracker
3
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"time"

	"github.com/charmbracelet/crush/internal/db"
)
14
// Service tracks which files have been read within a session, and when.
type Service interface {
	// RecordRead notes that the file at path was read during the session
	// identified by sessionID.
	RecordRead(ctx context.Context, sessionID, path string)

	// LastReadTime reports the most recent time the file at path was read
	// during the session identified by sessionID. The zero time.Time is
	// returned when no read has been recorded.
	LastReadTime(ctx context.Context, sessionID, path string) time.Time

	// ListReadFiles lists the path of every file read during the session
	// identified by sessionID.
	ListReadFiles(ctx context.Context, sessionID string) ([]string, error)
}
27
28type service struct {
29 q *db.Queries
30}
31
32// NewService creates a new file tracker service.
33func NewService(q *db.Queries) Service {
34 return &service{q: q}
35}
36
37// RecordRead records when a file was read.
38func (s *service) RecordRead(ctx context.Context, sessionID, path string) {
39 if err := s.q.RecordFileRead(ctx, db.RecordFileReadParams{
40 SessionID: sessionID,
41 Path: relpath(path),
42 }); err != nil {
43 slog.Error("Error recording file read", "error", err, "file", path)
44 }
45}
46
47// LastReadTime returns when a file was last read.
48// Returns zero time if never read.
49func (s *service) LastReadTime(ctx context.Context, sessionID, path string) time.Time {
50 readFile, err := s.q.GetFileRead(ctx, db.GetFileReadParams{
51 SessionID: sessionID,
52 Path: relpath(path),
53 })
54 if err != nil {
55 return time.Time{}
56 }
57
58 return time.Unix(readFile.ReadAt, 0)
59}
60
// relpath normalizes path into the form used as the stored key.
//
// The path is cleaned first. A relative path is treated as already relative
// to the current working directory and is returned as-is. An absolute path is
// made relative to the working directory. If the working directory cannot be
// determined, or no relative form exists, the cleaned absolute path is
// returned and a warning is logged.
func relpath(path string) string {
	path = filepath.Clean(path)
	if !filepath.IsAbs(path) {
		// filepath.Rel always fails for a relative target against an absolute
		// base. Short-circuit to avoid logging a spurious warning for a result
		// that would be this same path anyway.
		return path
	}

	basepath, err := os.Getwd()
	if err != nil {
		slog.Warn("Error getting basepath", "error", err)
		return path
	}

	rel, err := filepath.Rel(basepath, path)
	if err != nil {
		slog.Warn("Error getting relpath", "error", err, "path", path)
		return path
	}
	return rel
}
75
76// ListReadFiles returns the paths of all files read in a session.
77func (s *service) ListReadFiles(ctx context.Context, sessionID string) ([]string, error) {
78 readFiles, err := s.q.ListSessionReadFiles(ctx, sessionID)
79 if err != nil {
80 return nil, fmt.Errorf("listing read files: %w", err)
81 }
82
83 basepath, err := os.Getwd()
84 if err != nil {
85 return nil, fmt.Errorf("getting working directory: %w", err)
86 }
87
88 paths := make([]string, 0, len(readFiles))
89 for _, rf := range readFiles {
90 paths = append(paths, filepath.Join(basepath, rf.Path))
91 }
92 return paths, nil
93}