1package sqlite
  2
  3import (
  4	"context"
  5	"strings"
  6
  7	"github.com/charmbracelet/soft-serve/server/backend"
  8	"github.com/charmbracelet/soft-serve/server/utils"
  9	"github.com/jmoiron/sqlx"
 10	"golang.org/x/crypto/ssh"
 11)
 12
// User represents a user.
//
// It stores only the username and a database handle. Attributes such as the
// admin flag and the public keys are queried from the database each time the
// corresponding method is called.
type User struct {
	// username identifies the user; it is matched against user.username in
	// the queries issued by IsAdmin and PublicKeys.
	username string
	// db is the database handle those queries run against.
	db       *sqlx.DB
}

// Compile-time assertion that *User implements backend.User.
var _ backend.User = (*User)(nil)
 20
 21// IsAdmin returns whether the user is an admin.
 22//
 23// It implements backend.User.
 24func (u *User) IsAdmin() bool {
 25	var admin bool
 26	if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
 27		return tx.Get(&admin, "SELECT admin FROM user WHERE username = ?", u.username)
 28	}); err != nil {
 29		return false
 30	}
 31
 32	return admin
 33}
 34
 35// PublicKeys returns the user's public keys.
 36//
 37// It implements backend.User.
 38func (u *User) PublicKeys() []ssh.PublicKey {
 39	var keys []ssh.PublicKey
 40	if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
 41		var keyStrings []string
 42		if err := tx.Select(&keyStrings, `SELECT public_key
 43			FROM public_key
 44			INNER JOIN user ON user.id = public_key.user_id
 45			WHERE user.username = ?;`, u.username); err != nil {
 46			return err
 47		}
 48
 49		for _, keyString := range keyStrings {
 50			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
 51			if err != nil {
 52				return err
 53			}
 54			keys = append(keys, key)
 55		}
 56
 57		return nil
 58	}); err != nil {
 59		return nil
 60	}
 61
 62	return keys
 63}
 64
 65// Username returns the user's username.
 66//
 67// It implements backend.User.
 68func (u *User) Username() string {
 69	return u.username
 70}
 71
 72// AccessLevel returns the access level of a user for a repository.
 73//
 74// It implements backend.Backend.
 75func (d *SqliteBackend) AccessLevel(repo string, username string) backend.AccessLevel {
 76	anon := d.AnonAccess()
 77	user, _ := d.User(username)
 78	// If the user is an admin, they have admin access.
 79	if user != nil && user.IsAdmin() {
 80		return backend.AdminAccess
 81	}
 82
 83	// If the repository exists, check if the user is a collaborator.
 84	r, _ := d.Repository(repo)
 85	if r != nil {
 86		// If the user is a collaborator, they have read/write access.
 87		isCollab, _ := d.IsCollaborator(repo, username)
 88		if isCollab {
 89			if anon > backend.ReadWriteAccess {
 90				return anon
 91			}
 92			return backend.ReadWriteAccess
 93		}
 94
 95		// If the repository is private, the user has no access.
 96		if r.IsPrivate() {
 97			return backend.NoAccess
 98		}
 99
100		// Otherwise, the user has read-only access.
101		return backend.ReadOnlyAccess
102	}
103
104	// If the repository doesn't exist, the user has read/write access.
105	if user != nil {
106		// If the repository doesn't exist, the user has read/write access.
107		if anon > backend.ReadWriteAccess {
108			return anon
109		}
110
111		return backend.ReadWriteAccess
112	}
113
114	// If the user doesn't exist, give them the anonymous access level.
115	return anon
116}
117
118// AccessLevelByPublicKey returns the access level of a user's public key for a repository.
119//
120// It implements backend.Backend.
121func (d *SqliteBackend) AccessLevelByPublicKey(repo string, pk ssh.PublicKey) backend.AccessLevel {
122	ak := backend.MarshalAuthorizedKey(pk)
123	if strings.HasPrefix(d.cfg.InternalPublicKey, ak) {
124		return backend.AdminAccess
125	}
126	for _, k := range d.cfg.InitialAdminKeys {
127		if k == ak {
128			return backend.AdminAccess
129		}
130	}
131
132	user, _ := d.UserByPublicKey(pk)
133	if user != nil {
134		return d.AccessLevel(repo, user.Username())
135	}
136
137	return d.AccessLevel(repo, "")
138}
139
140// AddPublicKey adds a public key to a user.
141//
142// It implements backend.Backend.
143func (d *SqliteBackend) AddPublicKey(username string, pk ssh.PublicKey) error {
144	username = strings.ToLower(username)
145	if err := utils.ValidateUsername(username); err != nil {
146		return err
147	}
148
149	return wrapDbErr(
150		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
151			var userID int
152			if err := tx.Get(&userID, "SELECT id FROM user WHERE username = ?", username); err != nil {
153				return err
154			}
155
156			_, err := tx.Exec(`INSERT INTO public_key (user_id, public_key, updated_at)
157			VALUES (?, ?, CURRENT_TIMESTAMP);`, userID, backend.MarshalAuthorizedKey(pk))
158			return err
159		}),
160	)
161}
162
163// CreateUser creates a new user.
164//
165// It implements backend.Backend.
166func (d *SqliteBackend) CreateUser(username string, opts backend.UserOptions) (backend.User, error) {
167	username = strings.ToLower(username)
168	if err := utils.ValidateUsername(username); err != nil {
169		return nil, err
170	}
171
172	var user *User
173	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
174		stmt, err := tx.Prepare("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);")
175		if err != nil {
176			return err
177		}
178
179		defer stmt.Close() // nolint: errcheck
180		r, err := stmt.Exec(username, opts.Admin)
181		if err != nil {
182			return err
183		}
184
185		if len(opts.PublicKeys) > 0 {
186			userID, err := r.LastInsertId()
187			if err != nil {
188				logger.Error("error getting last insert id")
189				return err
190			}
191
192			for _, pk := range opts.PublicKeys {
193				stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
194					VALUES (?, ?, CURRENT_TIMESTAMP);`)
195				if err != nil {
196					return err
197				}
198
199				defer stmt.Close() // nolint: errcheck
200				if _, err := stmt.Exec(userID, backend.MarshalAuthorizedKey(pk)); err != nil {
201					return err
202				}
203			}
204		}
205
206		user = &User{
207			db:       d.db,
208			username: username,
209		}
210		return nil
211	}); err != nil {
212		return nil, wrapDbErr(err)
213	}
214
215	return user, nil
216}
217
218// DeleteUser deletes a user.
219//
220// It implements backend.Backend.
221func (d *SqliteBackend) DeleteUser(username string) error {
222	username = strings.ToLower(username)
223	if err := utils.ValidateUsername(username); err != nil {
224		return err
225	}
226
227	return wrapDbErr(
228		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
229			_, err := tx.Exec("DELETE FROM user WHERE username = ?", username)
230			return err
231		}),
232	)
233}
234
235// RemovePublicKey removes a public key from a user.
236//
237// It implements backend.Backend.
238func (d *SqliteBackend) RemovePublicKey(username string, pk ssh.PublicKey) error {
239	return wrapDbErr(
240		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
241			_, err := tx.Exec(`DELETE FROM public_key
242			WHERE user_id = (SELECT id FROM user WHERE username = ?)
243			AND public_key = ?;`, username, backend.MarshalAuthorizedKey(pk))
244			return err
245		}),
246	)
247}
248
249// ListPublicKeys lists the public keys of a user.
250func (d *SqliteBackend) ListPublicKeys(username string) ([]ssh.PublicKey, error) {
251	username = strings.ToLower(username)
252	if err := utils.ValidateUsername(username); err != nil {
253		return nil, err
254	}
255
256	keys := make([]ssh.PublicKey, 0)
257	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
258		var keyStrings []string
259		if err := tx.Select(&keyStrings, `SELECT public_key
260			FROM public_key
261			INNER JOIN user ON user.id = public_key.user_id
262			WHERE user.username = ?;`, username); err != nil {
263			return err
264		}
265
266		for _, keyString := range keyStrings {
267			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
268			if err != nil {
269				return err
270			}
271			keys = append(keys, key)
272		}
273
274		return nil
275	}); err != nil {
276		return nil, wrapDbErr(err)
277	}
278
279	return keys, nil
280}
281
282// SetUsername sets the username of a user.
283//
284// It implements backend.Backend.
285func (d *SqliteBackend) SetUsername(username string, newUsername string) error {
286	username = strings.ToLower(username)
287	if err := utils.ValidateUsername(username); err != nil {
288		return err
289	}
290
291	return wrapDbErr(
292		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
293			_, err := tx.Exec("UPDATE user SET username = ? WHERE username = ?", newUsername, username)
294			return err
295		}),
296	)
297}
298
299// SetAdmin sets the admin flag of a user.
300//
301// It implements backend.Backend.
302func (d *SqliteBackend) SetAdmin(username string, admin bool) error {
303	username = strings.ToLower(username)
304	if err := utils.ValidateUsername(username); err != nil {
305		return err
306	}
307
308	return wrapDbErr(
309		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
310			_, err := tx.Exec("UPDATE user SET admin = ? WHERE username = ?", admin, username)
311			return err
312		}),
313	)
314}
315
316// User finds a user by username.
317//
318// It implements backend.Backend.
319func (d *SqliteBackend) User(username string) (backend.User, error) {
320	username = strings.ToLower(username)
321	if err := utils.ValidateUsername(username); err != nil {
322		return nil, err
323	}
324
325	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
326		return tx.Get(&username, "SELECT username FROM user WHERE username = ?", username)
327	}); err != nil {
328		return nil, wrapDbErr(err)
329	}
330
331	return &User{
332		db:       d.db,
333		username: username,
334	}, nil
335}
336
337// UserByPublicKey finds a user by public key.
338//
339// It implements backend.Backend.
340func (d *SqliteBackend) UserByPublicKey(pk ssh.PublicKey) (backend.User, error) {
341	var username string
342	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
343		return tx.Get(&username, `SELECT user.username
344			FROM public_key
345			INNER JOIN user ON user.id = public_key.user_id
346			WHERE public_key.public_key = ?;`, backend.MarshalAuthorizedKey(pk))
347	}); err != nil {
348		return nil, wrapDbErr(err)
349	}
350
351	return &User{
352		db:       d.db,
353		username: username,
354	}, nil
355}
356
357// Users returns all users.
358//
359// It implements backend.Backend.
360func (d *SqliteBackend) Users() ([]string, error) {
361	var users []string
362	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
363		return tx.Select(&users, "SELECT username FROM user")
364	}); err != nil {
365		return nil, wrapDbErr(err)
366	}
367
368	return users, nil
369}