1package sqlite
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"log"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/soft-serve/server/db"
 13	"github.com/charmbracelet/soft-serve/server/db/types"
 14	"modernc.org/sqlite"
 15	sqlitelib "modernc.org/sqlite/lib"
 16)
 17
 18var _ db.Store = &Sqlite{}
 19
 20// Sqlite is a SQLite database.
 21type Sqlite struct {
 22	path string
 23	db   *sql.DB
 24}
 25
 26// New creates a new DB in the given path.
 27func New(path string) (*Sqlite, error) {
 28	var err error
 29	log.Printf("Opening SQLite db: %s\n", path)
 30	db, err := sql.Open("sqlite", path+
 31		"?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)")
 32	if err != nil {
 33		return nil, err
 34	}
 35	d := &Sqlite{
 36		db:   db,
 37		path: path,
 38	}
 39	if err = d.CreateDB(); err != nil {
 40		return nil, fmt.Errorf("failed to create db: %w", err)
 41	}
 42	return d, d.db.Ping()
 43}
 44
 45// Close closes the database.
 46func (d *Sqlite) Close() error {
 47	return d.db.Close()
 48}
 49
 50// CreateDB creates the database and tables.
 51func (d *Sqlite) CreateDB() error {
 52	return d.wrapTransaction(func(tx *sql.Tx) error {
 53		if _, err := tx.Exec(sqlCreateUserTable); err != nil {
 54			return err
 55		}
 56		if _, err := tx.Exec(sqlCreatePublicKeyTable); err != nil {
 57			return err
 58		}
 59		if _, err := tx.Exec(sqlCreateRepoTable); err != nil {
 60			return err
 61		}
 62		if _, err := tx.Exec(sqlCreateCollabTable); err != nil {
 63			return err
 64		}
 65		return nil
 66	})
 67}
 68
 69// AddUser adds a new user.
 70func (d *Sqlite) AddUser(name, login, email, password string, isAdmin bool) error {
 71	var l *string
 72	var e *string
 73	var p *string
 74	if login != "" {
 75		login = strings.ToLower(login)
 76		l = &login
 77	}
 78	if email != "" {
 79		email = strings.ToLower(email)
 80		e = &email
 81	}
 82	if password != "" {
 83		p = &password
 84	}
 85	if err := d.wrapTransaction(func(tx *sql.Tx) error {
 86		if _, err := tx.Exec(sqlInsertUser, name, l, e, p, isAdmin); err != nil {
 87			return err
 88		}
 89		return nil
 90	}); err != nil {
 91		return err
 92	}
 93	return nil
 94}
 95
 96// DeleteUser deletes a user.
 97func (d *Sqlite) DeleteUser(id int) error {
 98	return d.wrapTransaction(func(tx *sql.Tx) error {
 99		_, err := tx.Exec(sqlDeleteUser, id)
100		return err
101	})
102}
103
104// GetUser returns a user by ID.
105func (d *Sqlite) GetUser(id int) (*types.User, error) {
106	var u types.User
107	if err := d.wrapTransaction(func(tx *sql.Tx) error {
108		r := tx.QueryRow(sqlSelectUser, id)
109		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
110			return err
111		}
112		return nil
113	}); err != nil {
114		return nil, err
115	}
116	return &u, nil
117}
118
119// GetUserByLogin returns a user by login.
120func (d *Sqlite) GetUserByLogin(login string) (*types.User, error) {
121	login = strings.ToLower(login)
122	var u types.User
123	if err := d.wrapTransaction(func(tx *sql.Tx) error {
124		r := tx.QueryRow(sqlSelectUserByLogin, login)
125		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
126			return err
127		}
128		return nil
129	}); err != nil {
130		return nil, err
131	}
132	return &u, nil
133}
134
135// GetUserByLogin returns a user by login.
136func (d *Sqlite) GetUserByEmail(email string) (*types.User, error) {
137	email = strings.ToLower(email)
138	var u types.User
139	if err := d.wrapTransaction(func(tx *sql.Tx) error {
140		r := tx.QueryRow(sqlSelectUserByEmail, email)
141		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
142			return err
143		}
144		return nil
145	}); err != nil {
146		return nil, err
147	}
148	return &u, nil
149}
150
151// GetUserByPublicKey returns a user by public key.
152func (d *Sqlite) GetUserByPublicKey(key string) (*types.User, error) {
153	var u types.User
154	if err := d.wrapTransaction(func(tx *sql.Tx) error {
155		r := tx.QueryRow(sqlSelectUserByPublicKey, key)
156		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
157			return err
158		}
159		return nil
160	}); err != nil {
161		return nil, err
162	}
163	return &u, nil
164}
165
166// SetUserName sets the user name.
167func (d *Sqlite) SetUserName(user *types.User, name string) error {
168	return d.wrapTransaction(func(tx *sql.Tx) error {
169		_, err := tx.Exec(sqlUpdateUserName, name, user.ID)
170		return err
171	})
172}
173
174// SetUserLogin sets the user login.
175func (d *Sqlite) SetUserLogin(user *types.User, login string) error {
176	if login == "" {
177		return fmt.Errorf("login cannot be empty")
178	}
179	login = strings.ToLower(login)
180	return d.wrapTransaction(func(tx *sql.Tx) error {
181		_, err := tx.Exec(sqlUpdateUserLogin, login, user.ID)
182		return err
183	})
184}
185
186// SetUserEmail sets the user email.
187func (d *Sqlite) SetUserEmail(user *types.User, email string) error {
188	if email == "" {
189		return fmt.Errorf("email cannot be empty")
190	}
191	email = strings.ToLower(email)
192	return d.wrapTransaction(func(tx *sql.Tx) error {
193		_, err := tx.Exec(sqlUpdateUserEmail, email, user.ID)
194		return err
195	})
196}
197
198// SetUserPassword sets the user password.
199func (d *Sqlite) SetUserPassword(user *types.User, password string) error {
200	if password == "" {
201		return fmt.Errorf("password cannot be empty")
202	}
203	return d.wrapTransaction(func(tx *sql.Tx) error {
204		_, err := tx.Exec(sqlUpdateUserPassword, password, user.ID)
205		return err
206	})
207}
208
209// SetUserAdmin sets the user admin.
210func (d *Sqlite) SetUserAdmin(user *types.User, admin bool) error {
211	return d.wrapTransaction(func(tx *sql.Tx) error {
212		_, err := tx.Exec(sqlUpdateUserAdmin, admin, user.ID)
213		return err
214	})
215}
216
217// CountUsers returns the number of users.
218func (d *Sqlite) CountUsers() (int, error) {
219	var count int
220	if err := d.wrapTransaction(func(tx *sql.Tx) error {
221		r := tx.QueryRow(sqlCountUsers)
222		if err := r.Scan(&count); err != nil {
223			return err
224		}
225		return nil
226	}); err != nil {
227		return 0, err
228	}
229	return count, nil
230}
231
232// AddUserPublicKey adds a new user public key.
233func (d *Sqlite) AddUserPublicKey(user *types.User, key string) error {
234	return d.wrapTransaction(func(tx *sql.Tx) error {
235		_, err := tx.Exec(sqlInsertPublicKey, user.ID, key)
236		return err
237	})
238}
239
240// DeleteUserPublicKey deletes a user public key.
241func (d *Sqlite) DeleteUserPublicKey(id int) error {
242	return d.wrapTransaction(func(tx *sql.Tx) error {
243		_, err := tx.Exec(sqlDeletePublicKey, id)
244		return err
245	})
246}
247
248// GetUserPublicKeys returns the user public keys.
249func (d *Sqlite) GetUserPublicKeys(user *types.User) ([]*types.PublicKey, error) {
250	keys := make([]*types.PublicKey, 0)
251	if err := d.wrapTransaction(func(tx *sql.Tx) error {
252		rows, err := tx.Query(sqlSelectUserPublicKeys, user.ID)
253		if err != nil {
254			return err
255		}
256		if err := rows.Err(); err != nil {
257			return err
258		}
259		defer rows.Close()
260		for rows.Next() {
261			var k types.PublicKey
262			if err := rows.Scan(&k.ID, &k.UserID, &k.PublicKey, &k.CreatedAt, &k.UpdatedAt); err != nil {
263				return err
264			}
265			keys = append(keys, &k)
266		}
267		return nil
268	}); err != nil {
269		return nil, err
270	}
271	return keys, nil
272}
273
274// AddRepo adds a new repo.
275func (d *Sqlite) AddRepo(name, projectName, description string, isPrivate bool) error {
276	name = strings.ToLower(name)
277	return d.wrapTransaction(func(tx *sql.Tx) error {
278		_, err := tx.Exec(sqlInsertRepo, name, projectName, description, isPrivate)
279		return err
280	})
281}
282
283// DeleteRepo deletes a repo.
284func (d *Sqlite) DeleteRepo(name string) error {
285	name = strings.ToLower(name)
286	return d.wrapTransaction(func(tx *sql.Tx) error {
287		_, err := tx.Exec(sqlDeleteRepoWithName, name)
288		return err
289	})
290}
291
292// GetRepo returns a repo by name.
293func (d *Sqlite) GetRepo(name string) (*types.Repo, error) {
294	name = strings.ToLower(name)
295	var r types.Repo
296	if err := d.wrapTransaction(func(tx *sql.Tx) error {
297		rows := tx.QueryRow(sqlSelectRepoByName, name)
298		if err := rows.Scan(&r.ID, &r.Name, &r.ProjectName, &r.Description, &r.Private, &r.CreatedAt, &r.UpdatedAt); err != nil {
299			return err
300		}
301		if err := rows.Err(); err != nil {
302			return err
303		}
304		return nil
305	}); err != nil {
306		return nil, err
307	}
308	return &r, nil
309}
310
311// SetRepoName sets the repo name.
312func (d *Sqlite) SetRepoName(name string, newName string) error {
313	name = strings.ToLower(name)
314	newName = strings.ToLower(newName)
315	return d.wrapTransaction(func(tx *sql.Tx) error {
316		_, err := tx.Exec(sqlUpdateRepoNameByName, newName, name)
317		return err
318	})
319}
320
321// SetRepoProjectName sets the repo project name.
322func (d *Sqlite) SetRepoProjectName(name string, projectName string) error {
323	name = strings.ToLower(name)
324	return d.wrapTransaction(func(tx *sql.Tx) error {
325		_, err := tx.Exec(sqlUpdateRepoProjectNameByName, projectName, name)
326		return err
327	})
328}
329
330// SetRepoDescription sets the repo description.
331func (d *Sqlite) SetRepoDescription(name string, description string) error {
332	name = strings.ToLower(name)
333	return d.wrapTransaction(func(tx *sql.Tx) error {
334		_, err := tx.Exec(sqlUpdateRepoDescriptionByName, description,
335			name)
336		return err
337	})
338}
339
340// SetRepoPrivate sets the repo private.
341func (d *Sqlite) SetRepoPrivate(name string, private bool) error {
342	name = strings.ToLower(name)
343	return d.wrapTransaction(func(tx *sql.Tx) error {
344		_, err := tx.Exec(sqlUpdateRepoPrivateByName, private, name)
345		return err
346	})
347}
348
349// AddRepoCollab adds a new repo collaborator.
350func (d *Sqlite) AddRepoCollab(repo string, user *types.User) error {
351	return d.wrapTransaction(func(tx *sql.Tx) error {
352		_, err := tx.Exec(sqlInsertCollabByName, repo, user.ID)
353		return err
354	})
355}
356
357// DeleteRepoCollab deletes a repo collaborator.
358func (d *Sqlite) DeleteRepoCollab(userID int, repoID int) error {
359	return d.wrapTransaction(func(tx *sql.Tx) error {
360		_, err := tx.Exec(sqlDeleteCollab, repoID, userID)
361		return err
362	})
363}
364
365// ListRepoCollabs returns a list of repo collaborators.
366func (d *Sqlite) ListRepoCollabs(repo string) ([]*types.User, error) {
367	collabs := make([]*types.User, 0)
368	if err := d.wrapTransaction(func(tx *sql.Tx) error {
369		rows, err := tx.Query(sqlSelectRepoCollabsByName, repo)
370		if err != nil {
371			return err
372		}
373		if err := rows.Err(); err != nil {
374			return err
375		}
376		defer rows.Close()
377		for rows.Next() {
378			var c types.User
379			if err := rows.Scan(&c.ID, &c.Name, &c.Login, &c.Email, &c.Admin, &c.CreatedAt, &c.UpdatedAt); err != nil {
380				return err
381			}
382			collabs = append(collabs, &c)
383		}
384		return nil
385	}); err != nil {
386		return nil, err
387	}
388	return collabs, nil
389}
390
391// ListRepoPublicKeys returns a list of repo public keys.
392func (d *Sqlite) ListRepoPublicKeys(repo string) ([]*types.PublicKey, error) {
393	keys := make([]*types.PublicKey, 0)
394	if err := d.wrapTransaction(func(tx *sql.Tx) error {
395		rows, err := tx.Query(sqlSelectRepoPublicKeysByName, repo)
396		if err != nil {
397			return err
398		}
399		if err := rows.Err(); err != nil {
400			return err
401		}
402		defer rows.Close()
403		for rows.Next() {
404			var k types.PublicKey
405			if err := rows.Scan(&k.ID, &k.UserID, &k.PublicKey, &k.CreatedAt, &k.UpdatedAt); err != nil {
406				return err
407			}
408			keys = append(keys, &k)
409		}
410		return nil
411	}); err != nil {
412		return nil, err
413	}
414	return keys, nil
415}
416
417// IsRepoPublicKeyCollab returns true if the public key is a collaborator for the repository.
418func (d *Sqlite) IsRepoPublicKeyCollab(repo string, key string) (bool, error) {
419	var count int
420	if err := d.wrapTransaction(func(tx *sql.Tx) error {
421		rows := tx.QueryRow(sqlSelectRepoPublicKeyCollabByName, repo, key)
422		if err := rows.Scan(&count); err != nil {
423			return err
424		}
425		if err := rows.Err(); err != nil {
426			return err
427		}
428		return nil
429	}); err != nil {
430		return false, err
431	}
432	return count > 0, nil
433}
434
435// WrapTransaction runs the given function within a transaction.
436func (d *Sqlite) wrapTransaction(f func(tx *sql.Tx) error) error {
437	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
438	defer cancel()
439	tx, err := d.db.BeginTx(ctx, nil)
440	if err != nil {
441		log.Printf("error starting transaction: %s", err)
442		return err
443	}
444	for {
445		err = f(tx)
446		if err != nil {
447			switch {
448			case errors.Is(err, sql.ErrNoRows):
449			default:
450				serr, ok := err.(*sqlite.Error)
451				if ok {
452					switch serr.Code() {
453					case sqlitelib.SQLITE_BUSY:
454						continue
455					}
456					log.Printf("error in transaction: %d: %s", serr.Code(), serr)
457				} else {
458					log.Printf("error in transaction: %s", err)
459				}
460			}
461			return err
462		}
463		err = tx.Commit()
464		if err != nil {
465			log.Printf("error committing transaction: %s", err)
466			return err
467		}
468		break
469	}
470	return nil
471}