1package history
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/db"
10 "github.com/charmbracelet/crush/internal/proto"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
// InitialVersion is the version number Create starts from when recording a
// file for the first time.
const InitialVersion = 0
18
19type File = proto.File
20
21type Service interface {
22 pubsub.Suscriber[File]
23 Create(ctx context.Context, sessionID, path, content string) (File, error)
24 CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
25 Get(ctx context.Context, id string) (File, error)
26 GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
27 ListBySession(ctx context.Context, sessionID string) ([]File, error)
28 ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
29 Delete(ctx context.Context, id string) error
30 DeleteSessionFiles(ctx context.Context, sessionID string) error
31}
32
33type service struct {
34 *pubsub.Broker[File]
35 db *sql.DB
36 q *db.Queries
37}
38
39func NewService(q *db.Queries, db *sql.DB) Service {
40 return &service{
41 Broker: pubsub.NewBroker[File](),
42 q: q,
43 db: db,
44 }
45}
46
47func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
48 return s.createWithVersion(ctx, sessionID, path, content, InitialVersion)
49}
50
// CreateVersion records content for path in sessionID as the version after
// the newest one returned by ListFilesByPath. If no record exists for path yet,
// the content is stored via Create at InitialVersion.
func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
	existing, err := s.q.ListFilesByPath(ctx, path)
	if err != nil {
		return File{}, err
	}

	if len(existing) == 0 {
		return s.Create(ctx, sessionID, path, content)
	}

	// ListFilesByPath orders rows by version DESC, created_at DESC, so the
	// first row carries the highest version recorded for this path.
	next := existing[0].Version + 1
	return s.createWithVersion(ctx, sessionID, path, content, next)
}
69
// createWithVersion inserts a new file record for path in sessionID, starting
// at the requested version. Each attempt runs the insert in its own
// transaction.
//
// If the insert fails with a "UNIQUE constraint failed" error (for example
// because that version is already taken), the version is incremented and the
// insert retried, up to maxRetries attempts in total. On success the new file
// is published as a pubsub.CreatedEvent and returned.
//
// Errors from beginning or committing a transaction, and insert errors other
// than uniqueness violations, are returned immediately. If every attempt hits a
// uniqueness violation, the last such error is returned wrapped with context.
func (s *service) createWithVersion(ctx context.Context, sessionID, path, content string, version int64) (File, error) {
	// Maximum number of insert attempts before giving up on version conflicts.
	const maxRetries = 3

	var lastErr error
	for range maxRetries {
		tx, err := s.db.BeginTx(ctx, nil)
		if err != nil {
			return File{}, fmt.Errorf("failed to begin transaction: %w", err)
		}

		dbFile, err := s.q.WithTx(tx).CreateFile(ctx, db.CreateFileParams{
			ID:        uuid.New().String(),
			SessionID: sessionID,
			Path:      path,
			Content:   content,
			Version:   version,
		})
		if err != nil {
			// Best-effort rollback: the insert error is the one worth
			// reporting, and the transaction is abandoned either way.
			_ = tx.Rollback()

			// The error text is the only signal used here to detect a
			// uniqueness violation; anything else is not retryable.
			if !strings.Contains(err.Error(), "UNIQUE constraint failed") {
				return File{}, err
			}
			lastErr = err
			version++
			continue
		}

		if err := tx.Commit(); err != nil {
			return File{}, fmt.Errorf("failed to commit transaction: %w", err)
		}

		file := s.fromDBItem(dbFile)
		s.Publish(pubsub.CreatedEvent, file)
		return file, nil
	}

	// Reaching here means every attempt collided with an existing version.
	return File{}, fmt.Errorf("create file %q: no free version after %d attempts: %w", path, maxRetries, lastErr)
}
122
// Get loads the file record with the given id and converts it to a File.
// Query errors are returned unchanged.
func (s *service) Get(ctx context.Context, id string) (File, error) {
	row, getErr := s.q.GetFile(ctx, id)
	if getErr != nil {
		return File{}, getErr
	}
	return s.fromDBItem(row), nil
}
130
// GetByPathAndSession loads the file record that the GetFileByPathAndSession
// query selects for path within sessionID. Query errors are returned unchanged.
// NOTE(review): which version is chosen when several exist is decided by the
// SQL query, presumably the latest — confirm against the query definition.
func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
	params := db.GetFileByPathAndSessionParams{
		SessionID: sessionID,
		Path:      path,
	}

	row, getErr := s.q.GetFileByPathAndSession(ctx, params)
	if getErr != nil {
		return File{}, getErr
	}
	return s.fromDBItem(row), nil
}
141
// ListBySession returns every file record stored for sessionID, in the order
// the ListFilesBySession query yields them. On success the result is non-nil,
// even when the session has no files.
func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
	rows, listErr := s.q.ListFilesBySession(ctx, sessionID)
	if listErr != nil {
		return nil, listErr
	}

	out := make([]File, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
153
// ListLatestSessionFiles returns the file records selected by the
// ListLatestSessionFiles query for sessionID, in query order. On success the
// result is non-nil, even when nothing matches.
func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
	rows, listErr := s.q.ListLatestSessionFiles(ctx, sessionID)
	if listErr != nil {
		return nil, listErr
	}

	out := make([]File, 0, len(rows))
	for idx := range rows {
		out = append(out, s.fromDBItem(rows[idx]))
	}
	return out, nil
}
165
// Delete removes the file record with the given id and publishes the removed
// record as a pubsub.DeletedEvent. If the record cannot be loaded or deleted,
// the error is returned and no event is published.
func (s *service) Delete(ctx context.Context, id string) error {
	// Load the record first so the deletion event can carry it.
	existing, loadErr := s.Get(ctx, id)
	if loadErr != nil {
		return loadErr
	}

	if delErr := s.q.DeleteFile(ctx, id); delErr != nil {
		return delErr
	}

	s.Publish(pubsub.DeletedEvent, existing)
	return nil
}
178
// DeleteSessionFiles removes every file record stored for sessionID and
// publishes a pubsub.DeletedEvent for each one as it is removed.
//
// Deletions are not wrapped in a transaction: the loop stops at the first
// error, and records already deleted at that point stay deleted.
func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
	files, err := s.ListBySession(ctx, sessionID)
	if err != nil {
		return err
	}

	for _, file := range files {
		// The listing already loaded each record, so delete it directly rather
		// than via Delete, which would re-fetch the same row by ID first.
		if err := s.q.DeleteFile(ctx, file.ID); err != nil {
			return err
		}
		s.Publish(pubsub.DeletedEvent, file)
	}
	return nil
}
192
193func (s *service) fromDBItem(item db.File) File {
194 return File{
195 ID: item.ID,
196 SessionID: item.SessionID,
197 Path: item.Path,
198 Content: item.Content,
199 Version: item.Version,
200 CreatedAt: item.CreatedAt,
201 UpdatedAt: item.UpdatedAt,
202 }
203}