1package history
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/db"
10 "github.com/charmbracelet/crush/internal/pubsub"
11 "github.com/google/uuid"
12)
13
// InitialVersion is the version number requested for the first snapshot of a
// path: Create always asks for it, and CreateVersion falls back to Create when
// no earlier snapshot of the path exists.
const InitialVersion = 0
17
// File is one stored snapshot of a path's content within a session, as
// returned by Service. Every successful Create or CreateVersion call stores a
// new File with its own ID.
type File struct {
	ID        string `json:"id"`         // row ID; a fresh UUID string when created by this service
	SessionID string `json:"session_id"` // session the snapshot was recorded in
	Path      string `json:"path"`       // path the content belongs to
	Content   string `json:"content"`    // content recorded for this version
	Version   int64  `json:"version"`    // Create requests InitialVersion; CreateVersion requests latest+1 (bumped on conflicts)
	CreatedAt int64  `json:"created_at"` // NOTE(review): set by the DB layer; presumably a Unix timestamp — confirm the unit
	UpdatedAt int64  `json:"updated_at"` // NOTE(review): set by the DB layer; presumably a Unix timestamp — confirm the unit
}
27
28type Service interface {
29 pubsub.Suscriber[File]
30 Create(ctx context.Context, sessionID, path, content string) (File, error)
31 CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
32 Get(ctx context.Context, id string) (File, error)
33 GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
34 ListBySession(ctx context.Context, sessionID string) ([]File, error)
35 ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
36 Delete(ctx context.Context, id string) error
37 DeleteSessionFiles(ctx context.Context, sessionID string) error
38}
39
// service is the database-backed implementation of Service.
//
// The embedded broker supplies Publish, used to emit File events.
// NOTE(review): it presumably also supplies the Suscriber methods that
// Service requires — confirm in the pubsub package.
type service struct {
	*pubsub.Broker[File]
	db *sql.DB     // opens the transactions that wrap each insert
	q  *db.Queries // query methods; WithTx rebinds them to a transaction
}
45
46func NewService(q *db.Queries, db *sql.DB) Service {
47 return &service{
48 Broker: pubsub.NewBroker[File](),
49 q: q,
50 db: db,
51 }
52}
53
54func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
55 return s.createWithVersion(ctx, sessionID, path, content, InitialVersion)
56}
57
58func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
59 // Get the latest version for this path
60 files, err := s.q.ListFilesByPath(ctx, path)
61 if err != nil {
62 return File{}, err
63 }
64
65 if len(files) == 0 {
66 // No previous versions, create initial
67 return s.Create(ctx, sessionID, path, content)
68 }
69
70 // Get the latest version
71 latestFile := files[0] // Files are ordered by version DESC, created_at DESC
72 nextVersion := latestFile.Version + 1
73
74 return s.createWithVersion(ctx, sessionID, path, content, nextVersion)
75}
76
77func (s *service) createWithVersion(ctx context.Context, sessionID, path, content string, version int64) (File, error) {
78 // Maximum number of retries for transaction conflicts
79 const maxRetries = 3
80 var file File
81 var err error
82
83 // Retry loop for transaction conflicts
84 for attempt := range maxRetries {
85 // Start a transaction
86 tx, txErr := s.db.BeginTx(ctx, nil)
87 if txErr != nil {
88 return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
89 }
90
91 // Create a new queries instance with the transaction
92 qtx := s.q.WithTx(tx)
93
94 // Try to create the file within the transaction
95 dbFile, txErr := qtx.CreateFile(ctx, db.CreateFileParams{
96 ID: uuid.New().String(),
97 SessionID: sessionID,
98 Path: path,
99 Content: content,
100 Version: version,
101 })
102 if txErr != nil {
103 // Rollback the transaction
104 tx.Rollback()
105
106 // Check if this is a uniqueness constraint violation
107 if strings.Contains(txErr.Error(), "UNIQUE constraint failed") {
108 if attempt < maxRetries-1 {
109 // If we have retries left, increment version and try again
110 version++
111 continue
112 }
113 }
114 return File{}, txErr
115 }
116
117 // Commit the transaction
118 if txErr = tx.Commit(); txErr != nil {
119 return File{}, fmt.Errorf("failed to commit transaction: %w", txErr)
120 }
121
122 file = s.fromDBItem(dbFile)
123 s.Publish(pubsub.CreatedEvent, file)
124 return file, nil
125 }
126
127 return file, err
128}
129
130func (s *service) Get(ctx context.Context, id string) (File, error) {
131 dbFile, err := s.q.GetFile(ctx, id)
132 if err != nil {
133 return File{}, err
134 }
135 return s.fromDBItem(dbFile), nil
136}
137
138func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
139 dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
140 Path: path,
141 SessionID: sessionID,
142 })
143 if err != nil {
144 return File{}, err
145 }
146 return s.fromDBItem(dbFile), nil
147}
148
149func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
150 dbFiles, err := s.q.ListFilesBySession(ctx, sessionID)
151 if err != nil {
152 return nil, err
153 }
154 files := make([]File, len(dbFiles))
155 for i, dbFile := range dbFiles {
156 files[i] = s.fromDBItem(dbFile)
157 }
158 return files, nil
159}
160
161func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
162 dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID)
163 if err != nil {
164 return nil, err
165 }
166 files := make([]File, len(dbFiles))
167 for i, dbFile := range dbFiles {
168 files[i] = s.fromDBItem(dbFile)
169 }
170 return files, nil
171}
172
173func (s *service) Delete(ctx context.Context, id string) error {
174 file, err := s.Get(ctx, id)
175 if err != nil {
176 return err
177 }
178 err = s.q.DeleteFile(ctx, id)
179 if err != nil {
180 return err
181 }
182 s.Publish(pubsub.DeletedEvent, file)
183 return nil
184}
185
186func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
187 files, err := s.ListBySession(ctx, sessionID)
188 if err != nil {
189 return err
190 }
191 for _, file := range files {
192 err = s.Delete(ctx, file.ID)
193 if err != nil {
194 return err
195 }
196 }
197 return nil
198}
199
200func (s *service) fromDBItem(item db.File) File {
201 return File{
202 ID: item.ID,
203 SessionID: item.SessionID,
204 Path: item.Path,
205 Content: item.Content,
206 Version: item.Version,
207 CreatedAt: item.CreatedAt,
208 UpdatedAt: item.UpdatedAt,
209 }
210}