sqlite.go

  1package sqlite
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"errors"
  7	"fmt"
  8	"log"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/soft-serve/server/db"
 13	"github.com/charmbracelet/soft-serve/server/db/types"
 14	"modernc.org/sqlite"
 15	sqlitelib "modernc.org/sqlite/lib"
 16)
 17
 18var _ db.Store = &Sqlite{}
 19
 20// Sqlite is a SQLite database.
 21type Sqlite struct {
 22	path string
 23	db   *sql.DB
 24}
 25
 26// New creates a new DB in the given path.
 27func New(path string) (*Sqlite, error) {
 28	var err error
 29	log.Printf("Opening SQLite db: %s\n", path)
 30	db, err := sql.Open("sqlite", path+
 31		"?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)")
 32	if err != nil {
 33		return nil, err
 34	}
 35	d := &Sqlite{
 36		db:   db,
 37		path: path,
 38	}
 39	if err = d.CreateDB(); err != nil {
 40		return nil, fmt.Errorf("failed to create db: %w", err)
 41	}
 42	return d, d.db.Ping()
 43}
 44
 45// Close closes the database.
 46func (d *Sqlite) Close() error {
 47	return d.db.Close()
 48}
 49
 50// CreateDB creates the database and tables.
 51func (d *Sqlite) CreateDB() error {
 52	return d.wrapTransaction(func(tx *sql.Tx) error {
 53		if _, err := tx.Exec(sqlCreateConfigTable); err != nil {
 54			return err
 55		}
 56		if _, err := tx.Exec(sqlCreateUserTable); err != nil {
 57			return err
 58		}
 59		if _, err := tx.Exec(sqlCreatePublicKeyTable); err != nil {
 60			return err
 61		}
 62		if _, err := tx.Exec(sqlCreateRepoTable); err != nil {
 63			return err
 64		}
 65		if _, err := tx.Exec(sqlCreateCollabTable); err != nil {
 66			return err
 67		}
 68		return nil
 69	})
 70}
 71
 72const defaultConfigID = 1
 73
 74// GetConfig returns the server config.
 75func (d *Sqlite) GetConfig() (*types.Config, error) {
 76	var c types.Config
 77	if err := d.wrapTransaction(func(tx *sql.Tx) error {
 78		r := tx.QueryRow(sqlSelectConfig, defaultConfigID)
 79		if err := r.Scan(&c.ID, &c.Name, &c.Host, &c.Port, &c.AnonAccess, &c.AllowKeyless, &c.CreatedAt, &c.UpdatedAt); err != nil {
 80			return err
 81		}
 82		return nil
 83	}); err != nil {
 84		return nil, err
 85	}
 86	return &c, nil
 87}
 88
 89// SetConfigName sets the server config name.
 90func (d *Sqlite) SetConfigName(name string) error {
 91	return d.wrapTransaction(func(tx *sql.Tx) error {
 92		_, err := tx.Exec(sqlUpdateConfigName, name, defaultConfigID)
 93		return err
 94	})
 95}
 96
 97// SetConfigHost sets the server config host.
 98func (d *Sqlite) SetConfigHost(host string) error {
 99	return d.wrapTransaction(func(tx *sql.Tx) error {
100		_, err := tx.Exec(sqlUpdateConfigHost, host, defaultConfigID)
101		return err
102	})
103}
104
// SetConfigPort writes port into the server config row.
func (d *Sqlite) SetConfigPort(port int) error {
	return d.wrapTransaction(func(tx *sql.Tx) error {
		if _, err := tx.Exec(sqlUpdateConfigPort, port, defaultConfigID); err != nil {
			return err
		}
		return nil
	})
}
112
// SetConfigAnonAccess writes the anonymous-access setting into the server
// config row. The string is stored as given.
func (d *Sqlite) SetConfigAnonAccess(access string) error {
	return d.wrapTransaction(func(tx *sql.Tx) error {
		if _, err := tx.Exec(sqlUpdateConfigAnon, access, defaultConfigID); err != nil {
			return err
		}
		return nil
	})
}
120
// SetConfigAllowKeyless writes the allow-keyless flag into the server
// config row.
func (d *Sqlite) SetConfigAllowKeyless(allow bool) error {
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateConfigKeyless, allow, defaultConfigID)
		return err
	}
	return d.wrapTransaction(update)
}
128
129// AddUser adds a new user.
130func (d *Sqlite) AddUser(name, login, email, password string, isAdmin bool) error {
131	var l *string
132	var e *string
133	var p *string
134	if login != "" {
135		login = strings.ToLower(login)
136		l = &login
137	}
138	if email != "" {
139		email = strings.ToLower(email)
140		e = &email
141	}
142	if password != "" {
143		p = &password
144	}
145	if err := d.wrapTransaction(func(tx *sql.Tx) error {
146		if _, err := tx.Exec(sqlInsertUser, name, l, e, p, isAdmin); err != nil {
147			return err
148		}
149		return nil
150	}); err != nil {
151		return err
152	}
153	return nil
154}
155
// DeleteUser removes the user whose primary key is id.
func (d *Sqlite) DeleteUser(id int) error {
	remove := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlDeleteUser, id)
		return err
	}
	return d.wrapTransaction(remove)
}
163
164// GetUser returns a user by ID.
165func (d *Sqlite) GetUser(id int) (*types.User, error) {
166	var u types.User
167	if err := d.wrapTransaction(func(tx *sql.Tx) error {
168		r := tx.QueryRow(sqlSelectUser, id)
169		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
170			return err
171		}
172		return nil
173	}); err != nil {
174		return nil, err
175	}
176	return &u, nil
177}
178
179// GetUserByLogin returns a user by login.
180func (d *Sqlite) GetUserByLogin(login string) (*types.User, error) {
181	login = strings.ToLower(login)
182	var u types.User
183	if err := d.wrapTransaction(func(tx *sql.Tx) error {
184		r := tx.QueryRow(sqlSelectUserByLogin, login)
185		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
186			return err
187		}
188		return nil
189	}); err != nil {
190		return nil, err
191	}
192	return &u, nil
193}
194
// GetUserByEmail looks up a user by email address. The address is
// lower-cased before the query, matching how AddUser and SetUserEmail store
// it.
//
// NOTE: wrapTransaction does not report sql.ErrNoRows, so an unknown email
// yields a zero-value user together with a nil error.
func (d *Sqlite) GetUserByEmail(email string) (*types.User, error) {
	email = strings.ToLower(email)
	var u types.User
	if err := d.wrapTransaction(func(tx *sql.Tx) error {
		return tx.QueryRow(sqlSelectUserByEmail, email).Scan(
			&u.ID, &u.Name, &u.Login, &u.Email,
			&u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt,
		)
	}); err != nil {
		return nil, err
	}
	return &u, nil
}
210
211// GetUserByPublicKey returns a user by public key.
212func (d *Sqlite) GetUserByPublicKey(key string) (*types.User, error) {
213	var u types.User
214	if err := d.wrapTransaction(func(tx *sql.Tx) error {
215		r := tx.QueryRow(sqlSelectUserByPublicKey, key)
216		if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
217			return err
218		}
219		return nil
220	}); err != nil {
221		return nil, err
222	}
223	return &u, nil
224}
225
// SetUserName updates the display name of user.
func (d *Sqlite) SetUserName(user *types.User, name string) error {
	return d.wrapTransaction(func(tx *sql.Tx) error {
		if _, err := tx.Exec(sqlUpdateUserName, name, user.ID); err != nil {
			return err
		}
		return nil
	})
}
233
// SetUserLogin updates the login of user. The login is lower-cased before
// it is stored, and an empty login is rejected without touching the
// database.
func (d *Sqlite) SetUserLogin(user *types.User, login string) error {
	if login == "" {
		return errors.New("login cannot be empty")
	}
	normalized := strings.ToLower(login)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserLogin, normalized, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
245
// SetUserEmail updates the email address of user. The address is
// lower-cased before it is stored, and an empty address is rejected without
// touching the database.
func (d *Sqlite) SetUserEmail(user *types.User, email string) error {
	if email == "" {
		return errors.New("email cannot be empty")
	}
	normalized := strings.ToLower(email)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserEmail, normalized, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
257
// SetUserPassword stores password for user exactly as given. This method
// does no hashing, so any hashing must be done by the caller. An empty
// password is rejected without touching the database.
func (d *Sqlite) SetUserPassword(user *types.User, password string) error {
	if password == "" {
		return errors.New("password cannot be empty")
	}
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserPassword, password, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
268
// SetUserAdmin sets or clears the admin flag of user.
func (d *Sqlite) SetUserAdmin(user *types.User, admin bool) error {
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateUserAdmin, admin, user.ID)
		return err
	}
	return d.wrapTransaction(update)
}
276
// CountUsers reports how many user rows exist. On error it returns 0.
func (d *Sqlite) CountUsers() (int, error) {
	var total int
	err := d.wrapTransaction(func(tx *sql.Tx) error {
		return tx.QueryRow(sqlCountUsers).Scan(&total)
	})
	if err != nil {
		return 0, err
	}
	return total, nil
}
291
// AddUserPublicKey attaches the public key string key to user.
func (d *Sqlite) AddUserPublicKey(user *types.User, key string) error {
	insert := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlInsertPublicKey, user.ID, key)
		return err
	}
	return d.wrapTransaction(insert)
}
299
// DeleteUserPublicKey removes the public key whose primary key is id.
func (d *Sqlite) DeleteUserPublicKey(id int) error {
	return d.wrapTransaction(func(tx *sql.Tx) error {
		if _, err := tx.Exec(sqlDeletePublicKey, id); err != nil {
			return err
		}
		return nil
	})
}
307
// GetUserPublicKeys returns every public key registered to user.
//
// An empty, non-nil slice is returned when the user has no keys. Errors
// raised while iterating the result set are reported rather than silently
// truncating the list.
func (d *Sqlite) GetUserPublicKeys(user *types.User) ([]*types.PublicKey, error) {
	var keys []*types.PublicKey
	if err := d.wrapTransaction(func(tx *sql.Tx) error {
		// Reset on every invocation so a retried transaction cannot leave
		// duplicates from an earlier, aborted pass.
		keys = make([]*types.PublicKey, 0)
		rows, err := tx.Query(sqlSelectUserPublicKeys, user.ID)
		if err != nil {
			return err
		}
		// Close immediately once owned so no return path leaks the cursor.
		defer rows.Close()
		for rows.Next() {
			var k types.PublicKey
			if err := rows.Scan(&k.ID, &k.UserID, &k.PublicKey, &k.CreatedAt, &k.UpdatedAt); err != nil {
				return err
			}
			keys = append(keys, &k)
		}
		// Iteration errors only become visible after Next returns false.
		return rows.Err()
	}); err != nil {
		return nil, err
	}
	return keys, nil
}
333
// AddRepo inserts a repository record. The repository name is lower-cased
// before it is stored; projectName and description are stored as given.
func (d *Sqlite) AddRepo(name, projectName, description string, isPrivate bool) error {
	lowered := strings.ToLower(name)
	insert := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlInsertRepo, lowered, projectName, description, isPrivate)
		return err
	}
	return d.wrapTransaction(insert)
}
342
// DeleteRepo removes the repository called name. The name is lower-cased
// before the delete is issued.
func (d *Sqlite) DeleteRepo(name string) error {
	lowered := strings.ToLower(name)
	return d.wrapTransaction(func(tx *sql.Tx) error {
		if _, err := tx.Exec(sqlDeleteRepoWithName, lowered); err != nil {
			return err
		}
		return nil
	})
}
351
352// GetRepo returns a repo by name.
353func (d *Sqlite) GetRepo(name string) (*types.Repo, error) {
354	name = strings.ToLower(name)
355	var r types.Repo
356	if err := d.wrapTransaction(func(tx *sql.Tx) error {
357		rows := tx.QueryRow(sqlSelectRepoByName, name)
358		if err := rows.Scan(&r.ID, &r.Name, &r.ProjectName, &r.Description, &r.Private, &r.CreatedAt, &r.UpdatedAt); err != nil {
359			return err
360		}
361		if err := rows.Err(); err != nil {
362			return err
363		}
364		return nil
365	}); err != nil {
366		return nil, err
367	}
368	return &r, nil
369}
370
// SetRepoProjectName updates the project name of the repository called
// name (looked up lower-cased).
func (d *Sqlite) SetRepoProjectName(name string, projectName string) error {
	lowered := strings.ToLower(name)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateRepoProjectNameByName, projectName, lowered)
		return err
	}
	return d.wrapTransaction(update)
}
379
// SetRepoDescription updates the description of the repository called
// name (looked up lower-cased).
func (d *Sqlite) SetRepoDescription(name string, description string) error {
	lowered := strings.ToLower(name)
	update := func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlUpdateRepoDescriptionByName, description, lowered)
		return err
	}
	return d.wrapTransaction(update)
}
389
// SetRepoPrivate sets or clears the private flag of the repository called
// name (looked up lower-cased).
func (d *Sqlite) SetRepoPrivate(name string, private bool) error {
	lowered := strings.ToLower(name)
	return d.wrapTransaction(func(tx *sql.Tx) error {
		if _, err := tx.Exec(sqlUpdateRepoPrivateByName, private, lowered); err != nil {
			return err
		}
		return nil
	})
}
398
// AddRepoCollab adds user as a collaborator on the repository called repo.
//
// The repo name is lower-cased first, the same normalization AddRepo
// applies when storing names and the other repo-name lookups in this file
// use. Without it a mixed-case name would not match the stored row
// (assuming sqlInsertCollabByName resolves the repo by name, as its
// identifier suggests).
func (d *Sqlite) AddRepoCollab(repo string, user *types.User) error {
	repo = strings.ToLower(repo)
	return d.wrapTransaction(func(tx *sql.Tx) error {
		_, err := tx.Exec(sqlInsertCollabByName, repo, user.ID)
		return err
	})
}
406
// DeleteRepoCollab removes the collaborator entry linking userID and
// repoID.
func (d *Sqlite) DeleteRepoCollab(userID int, repoID int) error {
	remove := func(tx *sql.Tx) error {
		// Arguments go repo first, then user. Presumably this matches the
		// placeholder order of sqlDeleteCollab (defined elsewhere) — verify.
		_, err := tx.Exec(sqlDeleteCollab, repoID, userID)
		return err
	}
	return d.wrapTransaction(remove)
}
414
// ListRepoCollabs returns the collaborators of the repository called repo.
//
// The repo name is lower-cased to match how AddRepo stores names. The
// returned users have no Password populated, because only ID, Name, Login,
// Email, Admin and timestamps are scanned. An empty, non-nil slice is
// returned when there are no collaborators.
func (d *Sqlite) ListRepoCollabs(repo string) ([]*types.User, error) {
	repo = strings.ToLower(repo)
	var collabs []*types.User
	if err := d.wrapTransaction(func(tx *sql.Tx) error {
		// Reset on every invocation so a retried transaction cannot leave
		// duplicates from an earlier, aborted pass.
		collabs = make([]*types.User, 0)
		rows, err := tx.Query(sqlSelectRepoCollabsByName, repo)
		if err != nil {
			return err
		}
		// Close immediately once owned so no return path leaks the cursor.
		defer rows.Close()
		for rows.Next() {
			var c types.User
			if err := rows.Scan(&c.ID, &c.Name, &c.Login, &c.Email, &c.Admin, &c.CreatedAt, &c.UpdatedAt); err != nil {
				return err
			}
			collabs = append(collabs, &c)
		}
		// Iteration errors only become visible after Next returns false.
		return rows.Err()
	}); err != nil {
		return nil, err
	}
	return collabs, nil
}
440
// ListRepoPublicKeys returns the public keys associated with the
// repository called repo, as selected by sqlSelectRepoPublicKeysByName.
//
// The repo name is lower-cased to match how AddRepo stores names. An empty,
// non-nil slice is returned when there are no keys.
func (d *Sqlite) ListRepoPublicKeys(repo string) ([]*types.PublicKey, error) {
	repo = strings.ToLower(repo)
	var keys []*types.PublicKey
	if err := d.wrapTransaction(func(tx *sql.Tx) error {
		// Reset on every invocation so a retried transaction cannot leave
		// duplicates from an earlier, aborted pass.
		keys = make([]*types.PublicKey, 0)
		rows, err := tx.Query(sqlSelectRepoPublicKeysByName, repo)
		if err != nil {
			return err
		}
		// Close immediately once owned so no return path leaks the cursor.
		defer rows.Close()
		for rows.Next() {
			var k types.PublicKey
			if err := rows.Scan(&k.ID, &k.UserID, &k.PublicKey, &k.CreatedAt, &k.UpdatedAt); err != nil {
				return err
			}
			keys = append(keys, &k)
		}
		// Iteration errors only become visible after Next returns false.
		return rows.Err()
	}); err != nil {
		return nil, err
	}
	return keys, nil
}
466
467// WrapTransaction runs the given function within a transaction.
468func (d *Sqlite) wrapTransaction(f func(tx *sql.Tx) error) error {
469	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
470	defer cancel()
471	tx, err := d.db.BeginTx(ctx, nil)
472	if err != nil {
473		log.Printf("error starting transaction: %s", err)
474		return err
475	}
476	for {
477		err = f(tx)
478		if err != nil && !errors.Is(err, sql.ErrNoRows) {
479			serr, ok := err.(*sqlite.Error)
480			if ok {
481				switch serr.Code() {
482				case sqlitelib.SQLITE_BUSY:
483					continue
484				}
485				log.Printf("error in transaction: %d: %s", serr.Code(), serr)
486			} else {
487				log.Printf("error in transaction: %s", err)
488			}
489			return err
490		}
491		err = tx.Commit()
492		if err != nil {
493			log.Printf("error committing transaction: %s", err)
494			return err
495		}
496		break
497	}
498	return nil
499}