file.go

  1package history
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"fmt"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/db"
 10	"github.com/charmbracelet/crush/internal/proto"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
 15const (
 16	InitialVersion = 0
 17)
 18
 19type File = proto.File
 20
 21type Service interface {
 22	pubsub.Suscriber[File]
 23	Create(ctx context.Context, sessionID, path, content string) (File, error)
 24	CreateVersion(ctx context.Context, sessionID, path, content string) (File, error)
 25	Get(ctx context.Context, id string) (File, error)
 26	GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error)
 27	ListBySession(ctx context.Context, sessionID string) ([]File, error)
 28	ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error)
 29	Delete(ctx context.Context, id string) error
 30	DeleteSessionFiles(ctx context.Context, sessionID string) error
 31}
 32
 33type service struct {
 34	*pubsub.Broker[File]
 35	db *sql.DB
 36	q  *db.Queries
 37}
 38
 39func NewService(q *db.Queries, db *sql.DB) Service {
 40	return &service{
 41		Broker: pubsub.NewBroker[File](),
 42		q:      q,
 43		db:     db,
 44	}
 45}
 46
 47func (s *service) Create(ctx context.Context, sessionID, path, content string) (File, error) {
 48	return s.createWithVersion(ctx, sessionID, path, content, InitialVersion)
 49}
 50
 51func (s *service) CreateVersion(ctx context.Context, sessionID, path, content string) (File, error) {
 52	// Get the latest version for this path
 53	files, err := s.q.ListFilesByPath(ctx, path)
 54	if err != nil {
 55		return File{}, err
 56	}
 57
 58	if len(files) == 0 {
 59		// No previous versions, create initial
 60		return s.Create(ctx, sessionID, path, content)
 61	}
 62
 63	// Get the latest version
 64	latestFile := files[0] // Files are ordered by version DESC, created_at DESC
 65	nextVersion := latestFile.Version + 1
 66
 67	return s.createWithVersion(ctx, sessionID, path, content, nextVersion)
 68}
 69
 70func (s *service) createWithVersion(ctx context.Context, sessionID, path, content string, version int64) (File, error) {
 71	// Maximum number of retries for transaction conflicts
 72	const maxRetries = 3
 73	var file File
 74	var err error
 75
 76	// Retry loop for transaction conflicts
 77	for attempt := range maxRetries {
 78		// Start a transaction
 79		tx, txErr := s.db.BeginTx(ctx, nil)
 80		if txErr != nil {
 81			return File{}, fmt.Errorf("failed to begin transaction: %w", txErr)
 82		}
 83
 84		// Create a new queries instance with the transaction
 85		qtx := s.q.WithTx(tx)
 86
 87		// Try to create the file within the transaction
 88		dbFile, txErr := qtx.CreateFile(ctx, db.CreateFileParams{
 89			ID:        uuid.New().String(),
 90			SessionID: sessionID,
 91			Path:      path,
 92			Content:   content,
 93			Version:   version,
 94		})
 95		if txErr != nil {
 96			// Rollback the transaction
 97			tx.Rollback()
 98
 99			// Check if this is a uniqueness constraint violation
100			if strings.Contains(txErr.Error(), "UNIQUE constraint failed") {
101				if attempt < maxRetries-1 {
102					// If we have retries left, increment version and try again
103					version++
104					continue
105				}
106			}
107			return File{}, txErr
108		}
109
110		// Commit the transaction
111		if txErr = tx.Commit(); txErr != nil {
112			return File{}, fmt.Errorf("failed to commit transaction: %w", txErr)
113		}
114
115		file = s.fromDBItem(dbFile)
116		s.Publish(pubsub.CreatedEvent, file)
117		return file, nil
118	}
119
120	return file, err
121}
122
123func (s *service) Get(ctx context.Context, id string) (File, error) {
124	dbFile, err := s.q.GetFile(ctx, id)
125	if err != nil {
126		return File{}, err
127	}
128	return s.fromDBItem(dbFile), nil
129}
130
131func (s *service) GetByPathAndSession(ctx context.Context, path, sessionID string) (File, error) {
132	dbFile, err := s.q.GetFileByPathAndSession(ctx, db.GetFileByPathAndSessionParams{
133		Path:      path,
134		SessionID: sessionID,
135	})
136	if err != nil {
137		return File{}, err
138	}
139	return s.fromDBItem(dbFile), nil
140}
141
142func (s *service) ListBySession(ctx context.Context, sessionID string) ([]File, error) {
143	dbFiles, err := s.q.ListFilesBySession(ctx, sessionID)
144	if err != nil {
145		return nil, err
146	}
147	files := make([]File, len(dbFiles))
148	for i, dbFile := range dbFiles {
149		files[i] = s.fromDBItem(dbFile)
150	}
151	return files, nil
152}
153
154func (s *service) ListLatestSessionFiles(ctx context.Context, sessionID string) ([]File, error) {
155	dbFiles, err := s.q.ListLatestSessionFiles(ctx, sessionID)
156	if err != nil {
157		return nil, err
158	}
159	files := make([]File, len(dbFiles))
160	for i, dbFile := range dbFiles {
161		files[i] = s.fromDBItem(dbFile)
162	}
163	return files, nil
164}
165
166func (s *service) Delete(ctx context.Context, id string) error {
167	file, err := s.Get(ctx, id)
168	if err != nil {
169		return err
170	}
171	err = s.q.DeleteFile(ctx, id)
172	if err != nil {
173		return err
174	}
175	s.Publish(pubsub.DeletedEvent, file)
176	return nil
177}
178
179func (s *service) DeleteSessionFiles(ctx context.Context, sessionID string) error {
180	files, err := s.ListBySession(ctx, sessionID)
181	if err != nil {
182		return err
183	}
184	for _, file := range files {
185		err = s.Delete(ctx, file.ID)
186		if err != nil {
187			return err
188		}
189	}
190	return nil
191}
192
193func (s *service) fromDBItem(item db.File) File {
194	return File{
195		ID:        item.ID,
196		SessionID: item.SessionID,
197		Path:      item.Path,
198		Content:   item.Content,
199		Version:   item.Version,
200		CreatedAt: item.CreatedAt,
201		UpdatedAt: item.UpdatedAt,
202	}
203}