sqlite.go

  1package sqlite
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"log"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/soft-serve/server/db"
 13	"github.com/charmbracelet/soft-serve/server/db/types"
 14	"modernc.org/sqlite"
 15	sqlitelib "modernc.org/sqlite/lib"
 16)
 17
 18var _ db.Store = &Sqlite{}
 19
 20// Sqlite is a SQLite database.
 21type Sqlite struct {
 22	path string
 23	db   *sql.DB
 24}
 25
 26// New creates a new DB in the given path.
 27func New(path string) (*Sqlite, error) {
 28	var err error
 29	log.Printf("Opening SQLite db: %s\n", path)
 30	db, err := sql.Open("sqlite", path+
 31		"?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)")
 32	if err != nil {
 33		return nil, err
 34	}
 35	d := &Sqlite{
 36		db:   db,
 37		path: path,
 38	}
 39	if err = d.CreateDB(); err != nil {
 40		return nil, fmt.Errorf("failed to create db: %w", err)
 41	}
 42	return d, d.db.Ping()
 43}
 44
 45// Close closes the database.
 46func (d *Sqlite) Close() error {
 47	return d.db.Close()
 48}
 49
 50// CreateDB creates the database and tables.
 51func (d *Sqlite) CreateDB() error {
 52	return d.wrapTransaction(func(tx *sql.Tx) error {
 53		if _, err := tx.Exec(sqlCreateUserTable); err != nil {
 54			return err
 55		}
 56		if _, err := tx.Exec(sqlCreatePublicKeyTable); err != nil {
 57			return err
 58		}
 59		if _, err := tx.Exec(sqlCreateRepoTable); err != nil {
 60			return err
 61		}
 62		if _, err := tx.Exec(sqlCreateCollabTable); err != nil {
 63			return err
 64		}
 65		return nil
 66	})
 67}
 68
 69// AddUser adds a new user.
 70func (d *Sqlite) AddUser(name, login, email, password string, isAdmin bool) error {
 71	var l *string
 72	var e *string
 73	var p *string
 74	if login != "" {
 75		login = strings.ToLower(login)
 76		l = &login
 77	}
 78	if email != "" {
 79		email = strings.ToLower(email)
 80		e = &email
 81	}
 82	if password != "" {
 83		p = &password
 84	}
 85	if err := d.wrapTransaction(func(tx *sql.Tx) error {
 86		if _, err := tx.Exec(sqlInsertUser, name, l, e, p, isAdmin); err != nil {
 87			return err
 88		}
 89		return nil
 90	}); err != nil {
 91		return err
 92	}
 93	return nil
 94}
 95
 96// DeleteUser deletes a user.
 97func (d *Sqlite) DeleteUser(id int) error {
 98	return d.wrapTransaction(func(tx *sql.Tx) error {
 99		_, err := tx.Exec(sqlDeleteUser, id)
100		return err
101	})
102}
103
104// GetUser returns a user by ID.
105func (d *Sqlite) GetUser(id int) (*types.User, error) {
106	var u types.User
107	if err := d.wrapTransaction(func(tx *sql.Tx) error {
108		r := tx.QueryRow(sqlSelectUser, id)
109		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
110			return err
111		}
112		return nil
113	}); err != nil {
114		return nil, err
115	}
116	return &u, nil
117}
118
119// GetUserByLogin returns a user by login.
120func (d *Sqlite) GetUserByLogin(login string) (*types.User, error) {
121	login = strings.ToLower(login)
122	var u types.User
123	if err := d.wrapTransaction(func(tx *sql.Tx) error {
124		r := tx.QueryRow(sqlSelectUserByLogin, login)
125		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
126			return err
127		}
128		return nil
129	}); err != nil {
130		return nil, err
131	}
132	return &u, nil
133}
134
// GetUserByEmail fetches a user by email address. The email is lowercased
// before the lookup, matching how AddUser and SetUserEmail store it. Scan
// errors are returned as-is, including sql.ErrNoRows when no row matches.
func (d *Sqlite) GetUserByEmail(email string) (*types.User, error) {
	email = strings.ToLower(email)
	var u types.User
	if err := d.wrapTransaction(func(tx *sql.Tx) error {
		r := tx.QueryRow(sqlSelectUserByEmail, email)
		return r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt)
	}); err != nil {
		return nil, err
	}
	return &u, nil
}
150
151// GetUserByPublicKey returns a user by public key.
152func (d *Sqlite) GetUserByPublicKey(key string) (*types.User, error) {
153	var u types.User
154	if err := d.wrapTransaction(func(tx *sql.Tx) error {
155		r := tx.QueryRow(sqlSelectUserByPublicKey, key)
156		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
157			return err
158		}
159		return nil
160	}); err != nil {
161		return nil, err
162	}
163	return &u, nil
164}
165
// SetUserName updates the display name of user (identified by user.ID).
func (d *Sqlite) SetUserName(user *types.User, name string) error {
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserName, name, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
173
// SetUserLogin updates the login of user (identified by user.ID). The new
// login is lowercased before it is stored. An empty login is rejected
// without touching the database.
func (d *Sqlite) SetUserLogin(user *types.User, login string) error {
	if len(login) == 0 {
		return errors.New("login cannot be empty")
	}
	lowered := strings.ToLower(login)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserLogin, lowered, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
185
// SetUserEmail updates the email of user (identified by user.ID). The new
// email is lowercased before it is stored. An empty email is rejected
// without touching the database.
func (d *Sqlite) SetUserEmail(user *types.User, email string) error {
	if len(email) == 0 {
		return errors.New("email cannot be empty")
	}
	lowered := strings.ToLower(email)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserEmail, lowered, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
197
// SetUserPassword stores password for user (identified by user.ID) exactly
// as given. An empty password is rejected without touching the database.
func (d *Sqlite) SetUserPassword(user *types.User, password string) error {
	if len(password) == 0 {
		return errors.New("password cannot be empty")
	}
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserPassword, password, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
208
// SetUserAdmin sets or clears the admin flag of user (identified by user.ID).
func (d *Sqlite) SetUserAdmin(user *types.User, admin bool) error {
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserAdmin, admin, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
216
// CountUsers returns the single value produced by the sqlCountUsers query.
// It returns 0 together with the error if the query or scan fails.
func (d *Sqlite) CountUsers() (int, error) {
	var total int
	err := d.wrapTransaction(func(tx *sql.Tx) error {
		return tx.QueryRow(sqlCountUsers).Scan(&total)
	})
	if err != nil {
		return 0, err
	}
	return total, nil
}
231
// AddUserPublicKey associates key with user (identified by user.ID). The
// key text is stored unmodified.
func (d *Sqlite) AddUserPublicKey(user *types.User, key string) error {
	insert := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlInsertPublicKey, user.ID, key)
		return err
	}
	return d.wrapTransaction(insert)
}
239
// DeleteUserPublicKey deletes the public key with the given ID.
func (d *Sqlite) DeleteUserPublicKey(id int) error {
	remove := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlDeletePublicKey, id)
		return err
	}
	return d.wrapTransaction(remove)
}
247
// GetUserPublicKeys returns the public keys belonging to user (identified by
// user.ID), in the order the sqlSelectUserPublicKeys query yields them. A
// user with no keys gets an empty, non-nil slice. On error the slice is nil.
func (d *Sqlite) GetUserPublicKeys(user *types.User) ([]*types.PublicKey, error) {
	keys := make([]*types.PublicKey, 0)
	if err := d.wrapTransaction(func(tx *sql.Tx) error {
		// wrapTransaction may call this function again after SQLITE_BUSY;
		// start from an empty result so a retry does not duplicate keys.
		keys = keys[:0]
		rows, err := tx.Query(sqlSelectUserPublicKeys, user.ID)
		if err != nil {
			return err
		}
		// Close on every return path, including scan failures.
		defer rows.Close()
		for rows.Next() {
			var k types.PublicKey
			if err := rows.Scan(&k.ID, &k.UserID, &k.PublicKey, &k.CreatedAt, &k.UpdatedAt); err != nil {
				return err
			}
			keys = append(keys, &k)
		}
		// Next returns false both at end of data and on failure; Err tells
		// the two apart and is only meaningful after the loop.
		return rows.Err()
	}); err != nil {
		return nil, err
	}
	return keys, nil
}
273
// AddRepo inserts a new repo. The repo name is lowercased before it is
// stored; projectName, description and isPrivate are stored as given.
func (d *Sqlite) AddRepo(name, projectName, description string, isPrivate bool) error {
	lowered := strings.ToLower(name)
	insert := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlInsertRepo, lowered, projectName, description, isPrivate)
		return err
	}
	return d.wrapTransaction(insert)
}
282
// DeleteRepo deletes the repo with the given name (lowercased first).
func (d *Sqlite) DeleteRepo(name string) error {
	lowered := strings.ToLower(name)
	remove := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlDeleteRepoWithName, lowered)
		return err
	}
	return d.wrapTransaction(remove)
}
291
292// GetRepo returns a repo by name.
293func (d *Sqlite) GetRepo(name string) (*types.Repo, error) {
294	name = strings.ToLower(name)
295	var r types.Repo
296	if err := d.wrapTransaction(func(tx *sql.Tx) error {
297		rows := tx.QueryRow(sqlSelectRepoByName, name)
298		if err := rows.Scan(&r.ID, &r.Name, &r.ProjectName, &r.Description, &r.Private, &r.CreatedAt, &r.UpdatedAt); err != nil {
299			return err
300		}
301		if err := rows.Err(); err != nil {
302			return err
303		}
304		return nil
305	}); err != nil {
306		return nil, err
307	}
308	return &r, nil
309}
310
// SetRepoProjectName updates the project name of the repo called name
// (lowercased first).
func (d *Sqlite) SetRepoProjectName(name string, projectName string) error {
	lowered := strings.ToLower(name)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateRepoProjectNameByName, projectName, lowered)
		return err
	}
	return d.wrapTransaction(update)
}
319
// SetRepoDescription updates the description of the repo called name
// (lowercased first).
func (d *Sqlite) SetRepoDescription(name string, description string) error {
	lowered := strings.ToLower(name)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateRepoDescriptionByName, description, lowered)
		return err
	}
	return d.wrapTransaction(update)
}
329
// SetRepoPrivate sets or clears the private flag of the repo called name
// (lowercased first).
func (d *Sqlite) SetRepoPrivate(name string, private bool) error {
	lowered := strings.ToLower(name)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateRepoPrivateByName, private, lowered)
		return err
	}
	return d.wrapTransaction(update)
}
338
// AddRepoCollab adds user (identified by user.ID) as a collaborator on the
// named repo.
//
// The repo name is lowercased before the lookup. This matches AddRepo, which
// stores names lowercased, and the other repo-by-name methods, which
// lowercase their argument. Without it, a mixed-case name would not match.
func (d *Sqlite) AddRepoCollab(repo string, user *types.User) error {
	repo = strings.ToLower(repo)
	return d.wrapTransaction(func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlInsertCollabByName, repo, user.ID)
		return err
	})
}
346
// DeleteRepoCollab removes the collaborator link between userID and repoID.
//
// NOTE(review): the query arguments are (repoID, userID), the reverse of this
// method's parameter order. Confirm this matches the placeholder order in
// sqlDeleteCollab.
func (d *Sqlite) DeleteRepoCollab(userID int, repoID int) error {
	remove := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlDeleteCollab, repoID, userID)
		return err
	}
	return d.wrapTransaction(remove)
}
354
// ListRepoCollabs returns the collaborators of the named repo.
//
// The repo name is lowercased before the lookup, matching how AddRepo stores
// names. Only the columns scanned below are populated; Password is left
// zero. A repo with no collaborators gets an empty, non-nil slice. On error
// the slice is nil.
func (d *Sqlite) ListRepoCollabs(repo string) ([]*types.User, error) {
	repo = strings.ToLower(repo)
	collabs := make([]*types.User, 0)
	if err := d.wrapTransaction(func(tx *sql.Tx) error {
		// wrapTransaction may call this function again after SQLITE_BUSY;
		// start from an empty result so a retry does not duplicate users.
		collabs = collabs[:0]
		rows, err := tx.Query(sqlSelectRepoCollabsByName, repo)
		if err != nil {
			return err
		}
		// Close on every return path, including scan failures.
		defer rows.Close()
		for rows.Next() {
			var c types.User
			if err := rows.Scan(&c.ID, &c.Name, &c.Login, &c.Email, &c.Admin, &c.CreatedAt, &c.UpdatedAt); err != nil {
				return err
			}
			collabs = append(collabs, &c)
		}
		// Next returns false both at end of data and on failure; Err tells
		// the two apart and is only meaningful after the loop.
		return rows.Err()
	}); err != nil {
		return nil, err
	}
	return collabs, nil
}
380
// ListRepoPublicKeys returns the public keys associated with the named repo,
// as yielded by the sqlSelectRepoPublicKeysByName query.
//
// The repo name is lowercased before the lookup, matching how AddRepo stores
// names. A repo with no keys gets an empty, non-nil slice. On error the
// slice is nil.
func (d *Sqlite) ListRepoPublicKeys(repo string) ([]*types.PublicKey, error) {
	repo = strings.ToLower(repo)
	keys := make([]*types.PublicKey, 0)
	if err := d.wrapTransaction(func(tx *sql.Tx) error {
		// wrapTransaction may call this function again after SQLITE_BUSY;
		// start from an empty result so a retry does not duplicate keys.
		keys = keys[:0]
		rows, err := tx.Query(sqlSelectRepoPublicKeysByName, repo)
		if err != nil {
			return err
		}
		// Close on every return path, including scan failures.
		defer rows.Close()
		for rows.Next() {
			var k types.PublicKey
			if err := rows.Scan(&k.ID, &k.UserID, &k.PublicKey, &k.CreatedAt, &k.UpdatedAt); err != nil {
				return err
			}
			keys = append(keys, &k)
		}
		// Next returns false both at end of data and on failure; Err tells
		// the two apart and is only meaningful after the loop.
		return rows.Err()
	}); err != nil {
		return nil, err
	}
	return keys, nil
}
406
407// WrapTransaction runs the given function within a transaction.
408func (d *Sqlite) wrapTransaction(f func(tx *sql.Tx) error) error {
409	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
410	defer cancel()
411	tx, err := d.db.BeginTx(ctx, nil)
412	if err != nil {
413		log.Printf("error starting transaction: %s", err)
414		return err
415	}
416	for {
417		err = f(tx)
418		if err != nil {
419			switch {
420			case errors.Is(err, sql.ErrNoRows):
421			default:
422				serr, ok := err.(*sqlite.Error)
423				if ok {
424					switch serr.Code() {
425					case sqlitelib.SQLITE_BUSY:
426						continue
427					}
428					log.Printf("error in transaction: %d: %s", serr.Code(), serr)
429				} else {
430					log.Printf("error in transaction: %s", err)
431				}
432			}
433			return err
434		}
435		err = tx.Commit()
436		if err != nil {
437			log.Printf("error committing transaction: %s", err)
438			return err
439		}
440		break
441	}
442	return nil
443}