1package sqlite
  2
  3import (
  4	"context"
  5	"strings"
  6
  7	"github.com/charmbracelet/soft-serve/server/backend"
  8	"github.com/charmbracelet/soft-serve/server/utils"
  9	"github.com/jmoiron/sqlx"
 10	"golang.org/x/crypto/ssh"
 11)
 12
 13// User represents a user.
 14type User struct {
 15	username string
 16	db       *sqlx.DB
 17}
 18
 19var _ backend.User = (*User)(nil)
 20
 21// IsAdmin returns whether the user is an admin.
 22//
 23// It implements backend.User.
 24func (u *User) IsAdmin() bool {
 25	var admin bool
 26	if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
 27		return tx.Get(&admin, "SELECT admin FROM user WHERE username = ?", u.username)
 28	}); err != nil {
 29		return false
 30	}
 31
 32	return admin
 33}
 34
 35// PublicKeys returns the user's public keys.
 36//
 37// It implements backend.User.
 38func (u *User) PublicKeys() []ssh.PublicKey {
 39	var keys []ssh.PublicKey
 40	if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
 41		var keyStrings []string
 42		if err := tx.Select(&keyStrings, `SELECT public_key
 43			FROM public_key
 44			INNER JOIN user ON user.id = public_key.user_id
 45			WHERE user.username = ?;`, u.username); err != nil {
 46			return err
 47		}
 48
 49		for _, keyString := range keyStrings {
 50			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
 51			if err != nil {
 52				return err
 53			}
 54			keys = append(keys, key)
 55		}
 56
 57		return nil
 58	}); err != nil {
 59		return nil
 60	}
 61
 62	return keys
 63}
 64
 65// Username returns the user's username.
 66//
 67// It implements backend.User.
 68func (u *User) Username() string {
 69	return u.username
 70}
 71
 72// AccessLevel returns the access level of a user for a repository.
 73//
 74// It implements backend.Backend.
 75func (d *SqliteBackend) AccessLevel(repo string, username string) backend.AccessLevel {
 76	anon := d.AnonAccess()
 77	user, _ := d.User(username)
 78	// If the user is an admin, they have admin access.
 79	if user != nil && user.IsAdmin() {
 80		return backend.AdminAccess
 81	}
 82
 83	// If the repository exists, check if the user is a collaborator.
 84	r, _ := d.Repository(repo)
 85	if r != nil {
 86		// If the user is a collaborator, they have read/write access.
 87		isCollab, _ := d.IsCollaborator(repo, username)
 88		if isCollab {
 89			if anon > backend.ReadWriteAccess {
 90				return anon
 91			}
 92			return backend.ReadWriteAccess
 93		}
 94
 95		// If the repository is private, the user has no access.
 96		if r.IsPrivate() {
 97			return backend.NoAccess
 98		}
 99
100		// Otherwise, the user has read-only access.
101		return backend.ReadOnlyAccess
102	}
103
104	if user != nil {
105		// If the repository doesn't exist, the user has read/write access.
106		if anon > backend.ReadWriteAccess {
107			return anon
108		}
109
110		return backend.ReadWriteAccess
111	}
112
113	// If the user doesn't exist, give them the anonymous access level.
114	return anon
115}
116
117// AccessLevelByPublicKey returns the access level of a user's public key for a repository.
118//
119// It implements backend.Backend.
120func (d *SqliteBackend) AccessLevelByPublicKey(repo string, pk ssh.PublicKey) backend.AccessLevel {
121	for _, k := range append(d.cfg.InitialAdminKeys, d.cfg.InternalPublicKey) {
122		ik, _, err := backend.ParseAuthorizedKey(k)
123		if err == nil && backend.KeysEqual(pk, ik) {
124			return backend.AdminAccess
125		}
126	}
127
128	user, _ := d.UserByPublicKey(pk)
129	if user != nil {
130		return d.AccessLevel(repo, user.Username())
131	}
132
133	return d.AccessLevel(repo, "")
134}
135
136// AddPublicKey adds a public key to a user.
137//
138// It implements backend.Backend.
139func (d *SqliteBackend) AddPublicKey(username string, pk ssh.PublicKey) error {
140	username = strings.ToLower(username)
141	if err := utils.ValidateUsername(username); err != nil {
142		return err
143	}
144
145	return wrapDbErr(
146		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
147			var userID int
148			if err := tx.Get(&userID, "SELECT id FROM user WHERE username = ?", username); err != nil {
149				return err
150			}
151
152			_, err := tx.Exec(`INSERT INTO public_key (user_id, public_key, updated_at)
153			VALUES (?, ?, CURRENT_TIMESTAMP);`, userID, backend.MarshalAuthorizedKey(pk))
154			return err
155		}),
156	)
157}
158
159// CreateUser creates a new user.
160//
161// It implements backend.Backend.
162func (d *SqliteBackend) CreateUser(username string, opts backend.UserOptions) (backend.User, error) {
163	username = strings.ToLower(username)
164	if err := utils.ValidateUsername(username); err != nil {
165		return nil, err
166	}
167
168	var user *User
169	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
170		stmt, err := tx.Prepare("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);")
171		if err != nil {
172			return err
173		}
174
175		defer stmt.Close() // nolint: errcheck
176		r, err := stmt.Exec(username, opts.Admin)
177		if err != nil {
178			return err
179		}
180
181		if len(opts.PublicKeys) > 0 {
182			userID, err := r.LastInsertId()
183			if err != nil {
184				logger.Error("error getting last insert id")
185				return err
186			}
187
188			for _, pk := range opts.PublicKeys {
189				stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
190					VALUES (?, ?, CURRENT_TIMESTAMP);`)
191				if err != nil {
192					return err
193				}
194
195				defer stmt.Close() // nolint: errcheck
196				if _, err := stmt.Exec(userID, backend.MarshalAuthorizedKey(pk)); err != nil {
197					return err
198				}
199			}
200		}
201
202		user = &User{
203			db:       d.db,
204			username: username,
205		}
206		return nil
207	}); err != nil {
208		return nil, wrapDbErr(err)
209	}
210
211	return user, nil
212}
213
214// DeleteUser deletes a user.
215//
216// It implements backend.Backend.
217func (d *SqliteBackend) DeleteUser(username string) error {
218	username = strings.ToLower(username)
219	if err := utils.ValidateUsername(username); err != nil {
220		return err
221	}
222
223	return wrapDbErr(
224		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
225			_, err := tx.Exec("DELETE FROM user WHERE username = ?", username)
226			return err
227		}),
228	)
229}
230
231// RemovePublicKey removes a public key from a user.
232//
233// It implements backend.Backend.
234func (d *SqliteBackend) RemovePublicKey(username string, pk ssh.PublicKey) error {
235	return wrapDbErr(
236		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
237			_, err := tx.Exec(`DELETE FROM public_key
238			WHERE user_id = (SELECT id FROM user WHERE username = ?)
239			AND public_key = ?;`, username, backend.MarshalAuthorizedKey(pk))
240			return err
241		}),
242	)
243}
244
245// ListPublicKeys lists the public keys of a user.
246func (d *SqliteBackend) ListPublicKeys(username string) ([]ssh.PublicKey, error) {
247	username = strings.ToLower(username)
248	if err := utils.ValidateUsername(username); err != nil {
249		return nil, err
250	}
251
252	keys := make([]ssh.PublicKey, 0)
253	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
254		var keyStrings []string
255		if err := tx.Select(&keyStrings, `SELECT public_key
256			FROM public_key
257			INNER JOIN user ON user.id = public_key.user_id
258			WHERE user.username = ?;`, username); err != nil {
259			return err
260		}
261
262		for _, keyString := range keyStrings {
263			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
264			if err != nil {
265				return err
266			}
267			keys = append(keys, key)
268		}
269
270		return nil
271	}); err != nil {
272		return nil, wrapDbErr(err)
273	}
274
275	return keys, nil
276}
277
278// SetUsername sets the username of a user.
279//
280// It implements backend.Backend.
281func (d *SqliteBackend) SetUsername(username string, newUsername string) error {
282	username = strings.ToLower(username)
283	if err := utils.ValidateUsername(username); err != nil {
284		return err
285	}
286
287	return wrapDbErr(
288		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
289			_, err := tx.Exec("UPDATE user SET username = ? WHERE username = ?", newUsername, username)
290			return err
291		}),
292	)
293}
294
295// SetAdmin sets the admin flag of a user.
296//
297// It implements backend.Backend.
298func (d *SqliteBackend) SetAdmin(username string, admin bool) error {
299	username = strings.ToLower(username)
300	if err := utils.ValidateUsername(username); err != nil {
301		return err
302	}
303
304	return wrapDbErr(
305		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
306			_, err := tx.Exec("UPDATE user SET admin = ? WHERE username = ?", admin, username)
307			return err
308		}),
309	)
310}
311
312// User finds a user by username.
313//
314// It implements backend.Backend.
315func (d *SqliteBackend) User(username string) (backend.User, error) {
316	username = strings.ToLower(username)
317	if err := utils.ValidateUsername(username); err != nil {
318		return nil, err
319	}
320
321	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
322		return tx.Get(&username, "SELECT username FROM user WHERE username = ?", username)
323	}); err != nil {
324		return nil, wrapDbErr(err)
325	}
326
327	return &User{
328		db:       d.db,
329		username: username,
330	}, nil
331}
332
333// UserByPublicKey finds a user by public key.
334//
335// It implements backend.Backend.
336func (d *SqliteBackend) UserByPublicKey(pk ssh.PublicKey) (backend.User, error) {
337	var username string
338	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
339		return tx.Get(&username, `SELECT user.username
340			FROM public_key
341			INNER JOIN user ON user.id = public_key.user_id
342			WHERE public_key.public_key = ?;`, backend.MarshalAuthorizedKey(pk))
343	}); err != nil {
344		return nil, wrapDbErr(err)
345	}
346
347	return &User{
348		db:       d.db,
349		username: username,
350	}, nil
351}
352
353// Users returns all users.
354//
355// It implements backend.Backend.
356func (d *SqliteBackend) Users() ([]string, error) {
357	var users []string
358	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
359		return tx.Select(&users, "SELECT username FROM user")
360	}); err != nil {
361		return nil, wrapDbErr(err)
362	}
363
364	return users, nil
365}