1package sqlite
2
3import (
4 "context"
5 "database/sql"
6 "errors"
7 "fmt"
8 "log"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/soft-serve/server/db"
13 "github.com/charmbracelet/soft-serve/server/db/types"
14 "modernc.org/sqlite"
15 sqlitelib "modernc.org/sqlite/lib"
16)
17
18var _ db.Store = &Sqlite{}
19
20// Sqlite is a SQLite database.
21type Sqlite struct {
22 path string
23 db *sql.DB
24}
25
26// New creates a new DB in the given path.
27func New(path string) (*Sqlite, error) {
28 var err error
29 log.Printf("Opening SQLite db: %s\n", path)
30 db, err := sql.Open("sqlite", path+
31 "?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)")
32 if err != nil {
33 return nil, err
34 }
35 d := &Sqlite{
36 db: db,
37 path: path,
38 }
39 if err = d.CreateDB(); err != nil {
40 return nil, fmt.Errorf("failed to create db: %w", err)
41 }
42 return d, d.db.Ping()
43}
44
45// Close closes the database.
46func (d *Sqlite) Close() error {
47 return d.db.Close()
48}
49
50// CreateDB creates the database and tables.
51func (d *Sqlite) CreateDB() error {
52 return d.wrapTransaction(func(tx *sql.Tx) error {
53 if _, err := tx.Exec(sqlCreateUserTable); err != nil {
54 return err
55 }
56 if _, err := tx.Exec(sqlCreatePublicKeyTable); err != nil {
57 return err
58 }
59 if _, err := tx.Exec(sqlCreateRepoTable); err != nil {
60 return err
61 }
62 if _, err := tx.Exec(sqlCreateCollabTable); err != nil {
63 return err
64 }
65 return nil
66 })
67}
68
69// AddUser adds a new user.
70func (d *Sqlite) AddUser(name, login, email, password string, isAdmin bool) error {
71 var l *string
72 var e *string
73 var p *string
74 if login != "" {
75 login = strings.ToLower(login)
76 l = &login
77 }
78 if email != "" {
79 email = strings.ToLower(email)
80 e = &email
81 }
82 if password != "" {
83 p = &password
84 }
85 if err := d.wrapTransaction(func(tx *sql.Tx) error {
86 if _, err := tx.Exec(sqlInsertUser, name, l, e, p, isAdmin); err != nil {
87 return err
88 }
89 return nil
90 }); err != nil {
91 return err
92 }
93 return nil
94}
95
96// DeleteUser deletes a user.
97func (d *Sqlite) DeleteUser(id int) error {
98 return d.wrapTransaction(func(tx *sql.Tx) error {
99 _, err := tx.Exec(sqlDeleteUser, id)
100 return err
101 })
102}
103
104// GetUser returns a user by ID.
105func (d *Sqlite) GetUser(id int) (*types.User, error) {
106 var u types.User
107 if err := d.wrapTransaction(func(tx *sql.Tx) error {
108 r := tx.QueryRow(sqlSelectUser, id)
109 if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
110 return err
111 }
112 return nil
113 }); err != nil {
114 return nil, err
115 }
116 return &u, nil
117}
118
119// GetUserByLogin returns a user by login.
120func (d *Sqlite) GetUserByLogin(login string) (*types.User, error) {
121 login = strings.ToLower(login)
122 var u types.User
123 if err := d.wrapTransaction(func(tx *sql.Tx) error {
124 r := tx.QueryRow(sqlSelectUserByLogin, login)
125 if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
126 return err
127 }
128 return nil
129 }); err != nil {
130 return nil, err
131 }
132 return &u, nil
133}
134
135// GetUserByLogin returns a user by login.
136func (d *Sqlite) GetUserByEmail(email string) (*types.User, error) {
137 email = strings.ToLower(email)
138 var u types.User
139 if err := d.wrapTransaction(func(tx *sql.Tx) error {
140 r := tx.QueryRow(sqlSelectUserByEmail, email)
141 if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
142 return err
143 }
144 return nil
145 }); err != nil {
146 return nil, err
147 }
148 return &u, nil
149}
150
151// GetUserByPublicKey returns a user by public key.
152func (d *Sqlite) GetUserByPublicKey(key string) (*types.User, error) {
153 var u types.User
154 if err := d.wrapTransaction(func(tx *sql.Tx) error {
155 r := tx.QueryRow(sqlSelectUserByPublicKey, key)
156 if err := r.Scan(&u.ID, &u.Name, &u.Login, &u.Email, &u.Password, &u.Admin, &u.CreatedAt, &u.UpdatedAt); err != nil {
157 return err
158 }
159 return nil
160 }); err != nil {
161 return nil, err
162 }
163 return &u, nil
164}
165
166// SetUserName sets the user name.
167func (d *Sqlite) SetUserName(user *types.User, name string) error {
168 return d.wrapTransaction(func(tx *sql.Tx) error {
169 _, err := tx.Exec(sqlUpdateUserName, name, user.ID)
170 return err
171 })
172}
173
174// SetUserLogin sets the user login.
175func (d *Sqlite) SetUserLogin(user *types.User, login string) error {
176 if login == "" {
177 return fmt.Errorf("login cannot be empty")
178 }
179 login = strings.ToLower(login)
180 return d.wrapTransaction(func(tx *sql.Tx) error {
181 _, err := tx.Exec(sqlUpdateUserLogin, login, user.ID)
182 return err
183 })
184}
185
186// SetUserEmail sets the user email.
187func (d *Sqlite) SetUserEmail(user *types.User, email string) error {
188 if email == "" {
189 return fmt.Errorf("email cannot be empty")
190 }
191 email = strings.ToLower(email)
192 return d.wrapTransaction(func(tx *sql.Tx) error {
193 _, err := tx.Exec(sqlUpdateUserEmail, email, user.ID)
194 return err
195 })
196}
197
198// SetUserPassword sets the user password.
199func (d *Sqlite) SetUserPassword(user *types.User, password string) error {
200 if password == "" {
201 return fmt.Errorf("password cannot be empty")
202 }
203 return d.wrapTransaction(func(tx *sql.Tx) error {
204 _, err := tx.Exec(sqlUpdateUserPassword, password, user.ID)
205 return err
206 })
207}
208
209// SetUserAdmin sets the user admin.
210func (d *Sqlite) SetUserAdmin(user *types.User, admin bool) error {
211 return d.wrapTransaction(func(tx *sql.Tx) error {
212 _, err := tx.Exec(sqlUpdateUserAdmin, admin, user.ID)
213 return err
214 })
215}
216
217// CountUsers returns the number of users.
218func (d *Sqlite) CountUsers() (int, error) {
219 var count int
220 if err := d.wrapTransaction(func(tx *sql.Tx) error {
221 r := tx.QueryRow(sqlCountUsers)
222 if err := r.Scan(&count); err != nil {
223 return err
224 }
225 return nil
226 }); err != nil {
227 return 0, err
228 }
229 return count, nil
230}
231
232// AddUserPublicKey adds a new user public key.
233func (d *Sqlite) AddUserPublicKey(user *types.User, key string) error {
234 return d.wrapTransaction(func(tx *sql.Tx) error {
235 _, err := tx.Exec(sqlInsertPublicKey, user.ID, key)
236 return err
237 })
238}
239
240// DeleteUserPublicKey deletes a user public key.
241func (d *Sqlite) DeleteUserPublicKey(id int) error {
242 return d.wrapTransaction(func(tx *sql.Tx) error {
243 _, err := tx.Exec(sqlDeletePublicKey, id)
244 return err
245 })
246}
247
248// GetUserPublicKeys returns the user public keys.
249func (d *Sqlite) GetUserPublicKeys(user *types.User) ([]*types.PublicKey, error) {
250 keys := make([]*types.PublicKey, 0)
251 if err := d.wrapTransaction(func(tx *sql.Tx) error {
252 rows, err := tx.Query(sqlSelectUserPublicKeys, user.ID)
253 if err != nil {
254 return err
255 }
256 if err := rows.Err(); err != nil {
257 return err
258 }
259 defer rows.Close()
260 for rows.Next() {
261 var k types.PublicKey
262 if err := rows.Scan(&k.ID, &k.UserID, &k.PublicKey, &k.CreatedAt, &k.UpdatedAt); err != nil {
263 return err
264 }
265 keys = append(keys, &k)
266 }
267 return nil
268 }); err != nil {
269 return nil, err
270 }
271 return keys, nil
272}
273
274// AddRepo adds a new repo.
275func (d *Sqlite) AddRepo(name, projectName, description string, isPrivate bool) error {
276 name = strings.ToLower(name)
277 return d.wrapTransaction(func(tx *sql.Tx) error {
278 _, err := tx.Exec(sqlInsertRepo, name, projectName, description, isPrivate)
279 return err
280 })
281}
282
283// DeleteRepo deletes a repo.
284func (d *Sqlite) DeleteRepo(name string) error {
285 name = strings.ToLower(name)
286 return d.wrapTransaction(func(tx *sql.Tx) error {
287 _, err := tx.Exec(sqlDeleteRepoWithName, name)
288 return err
289 })
290}
291
292// GetRepo returns a repo by name.
293func (d *Sqlite) GetRepo(name string) (*types.Repo, error) {
294 name = strings.ToLower(name)
295 var r types.Repo
296 if err := d.wrapTransaction(func(tx *sql.Tx) error {
297 rows := tx.QueryRow(sqlSelectRepoByName, name)
298 if err := rows.Scan(&r.ID, &r.Name, &r.ProjectName, &r.Description, &r.Private, &r.CreatedAt, &r.UpdatedAt); err != nil {
299 return err
300 }
301 if err := rows.Err(); err != nil {
302 return err
303 }
304 return nil
305 }); err != nil {
306 return nil, err
307 }
308 return &r, nil
309}
310
311// SetRepoProjectName sets the repo project name.
312func (d *Sqlite) SetRepoProjectName(name string, projectName string) error {
313 name = strings.ToLower(name)
314 return d.wrapTransaction(func(tx *sql.Tx) error {
315 _, err := tx.Exec(sqlUpdateRepoProjectNameByName, projectName, name)
316 return err
317 })
318}
319
320// SetRepoDescription sets the repo description.
321func (d *Sqlite) SetRepoDescription(name string, description string) error {
322 name = strings.ToLower(name)
323 return d.wrapTransaction(func(tx *sql.Tx) error {
324 _, err := tx.Exec(sqlUpdateRepoDescriptionByName, description,
325 name)
326 return err
327 })
328}
329
330// SetRepoPrivate sets the repo private.
331func (d *Sqlite) SetRepoPrivate(name string, private bool) error {
332 name = strings.ToLower(name)
333 return d.wrapTransaction(func(tx *sql.Tx) error {
334 _, err := tx.Exec(sqlUpdateRepoPrivateByName, private, name)
335 return err
336 })
337}
338
339// AddRepoCollab adds a new repo collaborator.
340func (d *Sqlite) AddRepoCollab(repo string, user *types.User) error {
341 return d.wrapTransaction(func(tx *sql.Tx) error {
342 _, err := tx.Exec(sqlInsertCollabByName, repo, user.ID)
343 return err
344 })
345}
346
347// DeleteRepoCollab deletes a repo collaborator.
348func (d *Sqlite) DeleteRepoCollab(userID int, repoID int) error {
349 return d.wrapTransaction(func(tx *sql.Tx) error {
350 _, err := tx.Exec(sqlDeleteCollab, repoID, userID)
351 return err
352 })
353}
354
355// ListRepoCollabs returns a list of repo collaborators.
356func (d *Sqlite) ListRepoCollabs(repo string) ([]*types.User, error) {
357 collabs := make([]*types.User, 0)
358 if err := d.wrapTransaction(func(tx *sql.Tx) error {
359 rows, err := tx.Query(sqlSelectRepoCollabsByName, repo)
360 if err != nil {
361 return err
362 }
363 if err := rows.Err(); err != nil {
364 return err
365 }
366 defer rows.Close()
367 for rows.Next() {
368 var c types.User
369 if err := rows.Scan(&c.ID, &c.Name, &c.Login, &c.Email, &c.Admin, &c.CreatedAt, &c.UpdatedAt); err != nil {
370 return err
371 }
372 collabs = append(collabs, &c)
373 }
374 return nil
375 }); err != nil {
376 return nil, err
377 }
378 return collabs, nil
379}
380
381// ListRepoPublicKeys returns a list of repo public keys.
382func (d *Sqlite) ListRepoPublicKeys(repo string) ([]*types.PublicKey, error) {
383 keys := make([]*types.PublicKey, 0)
384 if err := d.wrapTransaction(func(tx *sql.Tx) error {
385 rows, err := tx.Query(sqlSelectRepoPublicKeysByName, repo)
386 if err != nil {
387 return err
388 }
389 if err := rows.Err(); err != nil {
390 return err
391 }
392 defer rows.Close()
393 for rows.Next() {
394 var k types.PublicKey
395 if err := rows.Scan(&k.ID, &k.UserID, &k.PublicKey, &k.CreatedAt, &k.UpdatedAt); err != nil {
396 return err
397 }
398 keys = append(keys, &k)
399 }
400 return nil
401 }); err != nil {
402 return nil, err
403 }
404 return keys, nil
405}
406
407// IsRepoPublicKeyCollab returns true if the public key is a collaborator for the repository.
408func (d *Sqlite) IsRepoPublicKeyCollab(repo string, key string) (bool, error) {
409 var count int
410 if err := d.wrapTransaction(func(tx *sql.Tx) error {
411 rows := tx.QueryRow(sqlSelectRepoPublicKeyCollabByName, repo, key)
412 if err := rows.Scan(&count); err != nil {
413 return err
414 }
415 if err := rows.Err(); err != nil {
416 return err
417 }
418 return nil
419 }); err != nil {
420 return false, err
421 }
422 return count > 0, nil
423}
424
425// WrapTransaction runs the given function within a transaction.
426func (d *Sqlite) wrapTransaction(f func(tx *sql.Tx) error) error {
427 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
428 defer cancel()
429 tx, err := d.db.BeginTx(ctx, nil)
430 if err != nil {
431 log.Printf("error starting transaction: %s", err)
432 return err
433 }
434 for {
435 err = f(tx)
436 if err != nil {
437 switch {
438 case errors.Is(err, sql.ErrNoRows):
439 default:
440 serr, ok := err.(*sqlite.Error)
441 if ok {
442 switch serr.Code() {
443 case sqlitelib.SQLITE_BUSY:
444 continue
445 }
446 log.Printf("error in transaction: %d: %s", serr.Code(), serr)
447 } else {
448 log.Printf("error in transaction: %s", err)
449 }
450 }
451 return err
452 }
453 err = tx.Commit()
454 if err != nil {
455 log.Printf("error committing transaction: %s", err)
456 return err
457 }
458 break
459 }
460 return nil
461}