user.go

  1package sqlite
  2
  3import (
  4	"context"
  5	"strings"
  6
  7	"github.com/charmbracelet/soft-serve/server/backend"
  8	"github.com/charmbracelet/soft-serve/server/utils"
  9	"github.com/jmoiron/sqlx"
 10	"golang.org/x/crypto/ssh"
 11)
 12
// User represents a user.
//
// It is a lightweight handle: it holds only the username and a database
// handle, and its IsAdmin and PublicKeys methods query the database on every
// call instead of caching user data.
type User struct {
	username string   // name matched against user.username in queries
	db       *sqlx.DB // database the User methods read from
}

// Compile-time check that *User satisfies backend.User.
var _ backend.User = (*User)(nil)
 20
 21// IsAdmin returns whether the user is an admin.
 22//
 23// It implements backend.User.
 24func (u *User) IsAdmin() bool {
 25	var admin bool
 26	if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
 27		return tx.Get(&admin, "SELECT admin FROM user WHERE username = ?", u.username)
 28	}); err != nil {
 29		return false
 30	}
 31
 32	return admin
 33}
 34
 35// PublicKeys returns the user's public keys.
 36//
 37// It implements backend.User.
 38func (u *User) PublicKeys() []ssh.PublicKey {
 39	var keys []ssh.PublicKey
 40	if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
 41		var keyStrings []string
 42		if err := tx.Select(&keyStrings, `SELECT public_key
 43			FROM public_key
 44			INNER JOIN user ON user.id = public_key.user_id
 45			WHERE user.username = ?;`, u.username); err != nil {
 46			return err
 47		}
 48
 49		for _, keyString := range keyStrings {
 50			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
 51			if err != nil {
 52				return err
 53			}
 54			keys = append(keys, key)
 55		}
 56
 57		return nil
 58	}); err != nil {
 59		return nil
 60	}
 61
 62	return keys
 63}
 64
 65// Username returns the user's username.
 66//
 67// It implements backend.User.
 68func (u *User) Username() string {
 69	return u.username
 70}
 71
 72// AccessLevel returns the access level of a user for a repository.
 73//
 74// It implements backend.Backend.
 75func (d *SqliteBackend) AccessLevel(repo string, username string) backend.AccessLevel {
 76	anon := d.AnonAccess()
 77	user, _ := d.User(username)
 78	// If the user is an admin, they have admin access.
 79	if user != nil && user.IsAdmin() {
 80		return backend.AdminAccess
 81	}
 82
 83	// If the repository exists, check if the user is a collaborator.
 84	r, _ := d.Repository(repo)
 85	if r != nil {
 86		// If the user is a collaborator, they have read/write access.
 87		isCollab, _ := d.IsCollaborator(repo, username)
 88		if isCollab {
 89			if anon > backend.ReadWriteAccess {
 90				return anon
 91			}
 92			return backend.ReadWriteAccess
 93		}
 94
 95		// If the repository is private, the user has no access.
 96		if r.IsPrivate() {
 97			return backend.NoAccess
 98		}
 99
100		// Otherwise, the user has read-only access.
101		return backend.ReadOnlyAccess
102	}
103
104	if user != nil {
105		// If the repository doesn't exist, the user has read/write access.
106		if anon > backend.ReadWriteAccess {
107			return anon
108		}
109
110		return backend.ReadWriteAccess
111	}
112
113	// If the user doesn't exist, give them the anonymous access level.
114	return anon
115}
116
117// AccessLevelByPublicKey returns the access level of a user's public key for a repository.
118//
119// It implements backend.Backend.
120func (d *SqliteBackend) AccessLevelByPublicKey(repo string, pk ssh.PublicKey) backend.AccessLevel {
121	if ik, _, err := backend.ParseAuthorizedKey(d.cfg.InternalPublicKey); err == nil && backend.KeysEqual(ik, pk) {
122		return backend.AdminAccess
123	}
124	for _, k := range d.cfg.InitialAdminKeys {
125		ik, _, err := backend.ParseAuthorizedKey(k)
126		if err == nil && backend.KeysEqual(pk, ik) {
127			return backend.AdminAccess
128		}
129	}
130
131	user, _ := d.UserByPublicKey(pk)
132	if user != nil {
133		return d.AccessLevel(repo, user.Username())
134	}
135
136	return d.AccessLevel(repo, "")
137}
138
139// AddPublicKey adds a public key to a user.
140//
141// It implements backend.Backend.
142func (d *SqliteBackend) AddPublicKey(username string, pk ssh.PublicKey) error {
143	username = strings.ToLower(username)
144	if err := utils.ValidateUsername(username); err != nil {
145		return err
146	}
147
148	return wrapDbErr(
149		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
150			var userID int
151			if err := tx.Get(&userID, "SELECT id FROM user WHERE username = ?", username); err != nil {
152				return err
153			}
154
155			_, err := tx.Exec(`INSERT INTO public_key (user_id, public_key, updated_at)
156			VALUES (?, ?, CURRENT_TIMESTAMP);`, userID, backend.MarshalAuthorizedKey(pk))
157			return err
158		}),
159	)
160}
161
162// CreateUser creates a new user.
163//
164// It implements backend.Backend.
165func (d *SqliteBackend) CreateUser(username string, opts backend.UserOptions) (backend.User, error) {
166	username = strings.ToLower(username)
167	if err := utils.ValidateUsername(username); err != nil {
168		return nil, err
169	}
170
171	var user *User
172	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
173		stmt, err := tx.Prepare("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);")
174		if err != nil {
175			return err
176		}
177
178		defer stmt.Close() // nolint: errcheck
179		r, err := stmt.Exec(username, opts.Admin)
180		if err != nil {
181			return err
182		}
183
184		if len(opts.PublicKeys) > 0 {
185			userID, err := r.LastInsertId()
186			if err != nil {
187				logger.Error("error getting last insert id")
188				return err
189			}
190
191			for _, pk := range opts.PublicKeys {
192				stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
193					VALUES (?, ?, CURRENT_TIMESTAMP);`)
194				if err != nil {
195					return err
196				}
197
198				defer stmt.Close() // nolint: errcheck
199				if _, err := stmt.Exec(userID, backend.MarshalAuthorizedKey(pk)); err != nil {
200					return err
201				}
202			}
203		}
204
205		user = &User{
206			db:       d.db,
207			username: username,
208		}
209		return nil
210	}); err != nil {
211		return nil, wrapDbErr(err)
212	}
213
214	return user, nil
215}
216
217// DeleteUser deletes a user.
218//
219// It implements backend.Backend.
220func (d *SqliteBackend) DeleteUser(username string) error {
221	username = strings.ToLower(username)
222	if err := utils.ValidateUsername(username); err != nil {
223		return err
224	}
225
226	return wrapDbErr(
227		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
228			_, err := tx.Exec("DELETE FROM user WHERE username = ?", username)
229			return err
230		}),
231	)
232}
233
234// RemovePublicKey removes a public key from a user.
235//
236// It implements backend.Backend.
237func (d *SqliteBackend) RemovePublicKey(username string, pk ssh.PublicKey) error {
238	return wrapDbErr(
239		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
240			_, err := tx.Exec(`DELETE FROM public_key
241			WHERE user_id = (SELECT id FROM user WHERE username = ?)
242			AND public_key = ?;`, username, backend.MarshalAuthorizedKey(pk))
243			return err
244		}),
245	)
246}
247
248// ListPublicKeys lists the public keys of a user.
249func (d *SqliteBackend) ListPublicKeys(username string) ([]ssh.PublicKey, error) {
250	username = strings.ToLower(username)
251	if err := utils.ValidateUsername(username); err != nil {
252		return nil, err
253	}
254
255	keys := make([]ssh.PublicKey, 0)
256	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
257		var keyStrings []string
258		if err := tx.Select(&keyStrings, `SELECT public_key
259			FROM public_key
260			INNER JOIN user ON user.id = public_key.user_id
261			WHERE user.username = ?;`, username); err != nil {
262			return err
263		}
264
265		for _, keyString := range keyStrings {
266			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
267			if err != nil {
268				return err
269			}
270			keys = append(keys, key)
271		}
272
273		return nil
274	}); err != nil {
275		return nil, wrapDbErr(err)
276	}
277
278	return keys, nil
279}
280
281// SetUsername sets the username of a user.
282//
283// It implements backend.Backend.
284func (d *SqliteBackend) SetUsername(username string, newUsername string) error {
285	username = strings.ToLower(username)
286	if err := utils.ValidateUsername(username); err != nil {
287		return err
288	}
289
290	return wrapDbErr(
291		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
292			_, err := tx.Exec("UPDATE user SET username = ? WHERE username = ?", newUsername, username)
293			return err
294		}),
295	)
296}
297
298// SetAdmin sets the admin flag of a user.
299//
300// It implements backend.Backend.
301func (d *SqliteBackend) SetAdmin(username string, admin bool) error {
302	username = strings.ToLower(username)
303	if err := utils.ValidateUsername(username); err != nil {
304		return err
305	}
306
307	return wrapDbErr(
308		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
309			_, err := tx.Exec("UPDATE user SET admin = ? WHERE username = ?", admin, username)
310			return err
311		}),
312	)
313}
314
315// User finds a user by username.
316//
317// It implements backend.Backend.
318func (d *SqliteBackend) User(username string) (backend.User, error) {
319	username = strings.ToLower(username)
320	if err := utils.ValidateUsername(username); err != nil {
321		return nil, err
322	}
323
324	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
325		return tx.Get(&username, "SELECT username FROM user WHERE username = ?", username)
326	}); err != nil {
327		return nil, wrapDbErr(err)
328	}
329
330	return &User{
331		db:       d.db,
332		username: username,
333	}, nil
334}
335
336// UserByPublicKey finds a user by public key.
337//
338// It implements backend.Backend.
339func (d *SqliteBackend) UserByPublicKey(pk ssh.PublicKey) (backend.User, error) {
340	var username string
341	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
342		return tx.Get(&username, `SELECT user.username
343			FROM public_key
344			INNER JOIN user ON user.id = public_key.user_id
345			WHERE public_key.public_key = ?;`, backend.MarshalAuthorizedKey(pk))
346	}); err != nil {
347		return nil, wrapDbErr(err)
348	}
349
350	return &User{
351		db:       d.db,
352		username: username,
353	}, nil
354}
355
356// Users returns all users.
357//
358// It implements backend.Backend.
359func (d *SqliteBackend) Users() ([]string, error) {
360	var users []string
361	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
362		return tx.Select(&users, "SELECT username FROM user")
363	}); err != nil {
364		return nil, wrapDbErr(err)
365	}
366
367	return users, nil
368}