1package history
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/db"
10 "github.com/charmbracelet/crush/internal/pubsub"
11 "github.com/google/uuid"
12)
13
const (
	// InitialVersion is the version number requested for the first stored
	// version of a file (see Create). Later versions created through
	// CreateVersion are numbered from the latest existing version plus one.
	InitialVersion = 0
)
17
// File is one stored version of a file's content within a session.
type File struct {
	ID        string // row ID; a random UUID when created by this service
	SessionID string // session this version was recorded in
	Path      string // path of the tracked file
	Content   string // the content stored for this version
	Version   int64  // version number; the first version requests InitialVersion
	// CreatedAt and UpdatedAt are copied verbatim from the database row.
	// NOTE(review): presumably Unix timestamps — confirm the unit against the schema.
	CreatedAt int64
	UpdatedAt int64
}
27
// Service manages file versions and history for sessions.
//
// It embeds pubsub.Subscriber[File], so callers can subscribe to file events.
// The implementation in this file publishes pubsub.CreatedEvent when a version
// is stored and pubsub.DeletedEvent when one is deleted.
type Service interface {
	pubsub.Subscriber[File]

	// Create stores content for path in sessionID as version InitialVersion
	// (or the next free version if that number is already taken).
	Create(ctx context.Context, sessionID, path, content string) (File, error)

	// CreateVersion creates a new version of a file. The new version number
	// is one more than the latest existing version for path (looked up by
	// path alone, not scoped to sessionID); if no version exists yet, it
	// behaves like Create.
	CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)

	// Get returns the stored file version with the given ID.
	Get(ctx context.Context, id string) (File, error)
	// GetByPathAndSession returns a stored file for path within sessionID.
	// NOTE(review): which version is returned when several exist is decided
	// by the underlying query — confirm.
	GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
	// ListBySession returns the stored file versions recorded for sessionID.
	ListBySession(ctx context.Context, sessionID string) ([]File, error)
	// ListLatestSessionFiles returns the files for sessionID selected by the
	// ListLatestSessionFiles query (presumably the newest version of each
	// path — TODO confirm).
	ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
	// Delete removes the file version with the given ID.
	Delete(ctx context.Context, id string) error
	// DeleteSessionFiles removes every file version listed for sessionID,
	// stopping at the first error.
	DeleteSessionFiles(ctx context.Context, sessionID string) error
}
43
// service is the database-backed implementation of Service. The embedded
// broker is used to publish created and deleted files to subscribers.
type service struct {
	*pubsub.Broker[File]
	// db is used to open transactions for inserts.
	db *sql.DB
	// q runs the database queries, either directly or bound to a transaction.
	q *db.Queries
}
49
50func NewService(q *db.Queries, db *sql.DB) Service {
51 return &service{
52 Broker: pubsub.NewBroker[File](),
53 q: q,
54 db: db,
55 }
56}
57
58func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
59 return s.createWithVersion(ctx, sessionID, path, content, InitialVersion)
60}
61
62// CreateVersion creates a new version of a file with auto-incremented version
63// number. If no previous versions exist for the path, it creates the initial
64// version. The provided content is stored as the new version.
65func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
66 // Get the latest version for this path
67 files, err := s.q.ListFilesByPath(ctx, path)
68 if err != nil {
69 return File{}, err
70 }
71
72 if len(files) == 0 {
73 // No previous versions, create initial
74 return s.Create(ctx, sessionID, path, content)
75 }
76
77 // Get the latest version
78 latestFile := files[0] // Files are ordered by version DESC, created_at DESC
79 nextVersion := latestFile.Version + 1
80
81 return s.createWithVersion(ctx, sessionID, path, content, nextVersion)
82}
83
84func (s *service) createWithVersion(ctx context.Context, sessionID, path, content string, version int64) (File, error) {
85 // Maximum number of retries for transaction conflicts
86 const maxRetries = 3
87 var file File
88 var err error
89
90 // Retry loop for transaction conflicts
91 for attempt := range maxRetries {
92 // Start a transaction
93 tx, txErr := s.db.BeginTx(ctx, nil)
94 if txErr != nil {
95 return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
96 }
97
98 // Create a new queries instance with the transaction
99 qtx := s.q.WithTx(tx)
100
101 // Try to create the file within the transaction
102 dbFile, txErr := qtx.CreateFile(ctx, db.CreateFileParams{
103 ID: uuid.New().String(),
104 SessionID: sessionID,
105 Path: path,
106 Content: content,
107 Version: version,
108 })
109 if txErr != nil {
110 // Rollback the transaction
111 tx.Rollback()
112
113 // Check if this is a uniqueness constraint violation
114 if strings.Contains(txErr.Error(), "UNIQUE constraint failed") {
115 if attempt < maxRetries-1 {
116 // If we have retries left, increment version and try again
117 version++
118 continue
119 }
120 }
121 return File{}, txErr
122 }
123
124 // Commit the transaction
125 if txErr = tx.Commit(); txErr != nil {
126 return File{}, fmt.Errorf("failed to commit transaction: %w", txErr)
127 }
128
129 file = s.fromDBItem(dbFile)
130 s.Publish(pubsub.CreatedEvent, file)
131 return file, nil
132 }
133
134 return file, err
135}
136
137func (s *service) Get(ctx context.Context, id string) (File, error) {
138 dbFile, err := s.q.GetFile(ctx, id)
139 if err != nil {
140 return File{}, err
141 }
142 return s.fromDBItem(dbFile), nil
143}
144
145func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
146 dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
147 Path: path,
148 SessionID: sessionID,
149 })
150 if err != nil {
151 return File{}, err
152 }
153 return s.fromDBItem(dbFile), nil
154}
155
156func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
157 dbFiles, err := s.q.ListFilesBySession(ctx, sessionID)
158 if err != nil {
159 return nil, err
160 }
161 files := make([]File, len(dbFiles))
162 for i, dbFile := range dbFiles {
163 files[i] = s.fromDBItem(dbFile)
164 }
165 return files, nil
166}
167
168func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
169 dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID)
170 if err != nil {
171 return nil, err
172 }
173 files := make([]File, len(dbFiles))
174 for i, dbFile := range dbFiles {
175 files[i] = s.fromDBItem(dbFile)
176 }
177 return files, nil
178}
179
180func (s *service) Delete(ctx context.Context, id string) error {
181 file, err := s.Get(ctx, id)
182 if err != nil {
183 return err
184 }
185 err = s.q.DeleteFile(ctx, id)
186 if err != nil {
187 return err
188 }
189 s.Publish(pubsub.DeletedEvent, file)
190 return nil
191}
192
193func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
194 files, err := s.ListBySession(ctx, sessionID)
195 if err != nil {
196 return err
197 }
198 for _, file := range files {
199 err = s.Delete(ctx, file.ID)
200 if err != nil {
201 return err
202 }
203 }
204 return nil
205}
206
207func (s *service) fromDBItem(item db.File) File {
208 return File{
209 ID: item.ID,
210 SessionID: item.SessionID,
211 Path: item.Path,
212 Content: item.Content,
213 Version: item.Version,
214 CreatedAt: item.CreatedAt,
215 UpdatedAt: item.UpdatedAt,
216 }
217}