1package sqlite
  2
  3import (
  4	"context"
  5	"strings"
  6
  7	"github.com/charmbracelet/soft-serve/server/backend"
  8	"github.com/charmbracelet/soft-serve/server/utils"
  9	"github.com/jmoiron/sqlx"
 10	"golang.org/x/crypto/ssh"
 11)
 12
 13// User represents a user.
 14type User struct {
 15	username string
 16	db       *sqlx.DB
 17}
 18
 19var _ backend.User = (*User)(nil)
 20
 21// IsAdmin returns whether the user is an admin.
 22//
 23// It implements backend.User.
 24func (u *User) IsAdmin() bool {
 25	var admin bool
 26	if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
 27		return tx.Get(&admin, "SELECT admin FROM user WHERE username = ?", u.username)
 28	}); err != nil {
 29		return false
 30	}
 31
 32	return admin
 33}
 34
 35// PublicKeys returns the user's public keys.
 36//
 37// It implements backend.User.
 38func (u *User) PublicKeys() []ssh.PublicKey {
 39	var keys []ssh.PublicKey
 40	if err := wrapTx(u.db, context.Background(), func(tx *sqlx.Tx) error {
 41		var keyStrings []string
 42		if err := tx.Select(&keyStrings, `SELECT public_key
 43			FROM public_key
 44			INNER JOIN user ON user.id = public_key.user_id
 45			WHERE user.username = ?;`, u.username); err != nil {
 46			return err
 47		}
 48
 49		for _, keyString := range keyStrings {
 50			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
 51			if err != nil {
 52				return err
 53			}
 54			keys = append(keys, key)
 55		}
 56
 57		return nil
 58	}); err != nil {
 59		return nil
 60	}
 61
 62	return keys
 63}
 64
 65// Username returns the user's username.
 66//
 67// It implements backend.User.
 68func (u *User) Username() string {
 69	return u.username
 70}
 71
 72// AccessLevel returns the access level of a user for a repository.
 73//
 74// It implements backend.Backend.
 75func (d *SqliteBackend) AccessLevel(repo string, username string) backend.AccessLevel {
 76	anon := d.AnonAccess()
 77	user, _ := d.User(username)
 78	// If the user is an admin, they have admin access.
 79	if user != nil && user.IsAdmin() {
 80		return backend.AdminAccess
 81	}
 82
 83	// If the repository exists, check if the user is a collaborator.
 84	r, _ := d.Repository(repo)
 85	if r != nil {
 86		// If the user is a collaborator, they have read/write access.
 87		isCollab, _ := d.IsCollaborator(repo, username)
 88		if isCollab {
 89			if anon > backend.ReadWriteAccess {
 90				return anon
 91			}
 92			return backend.ReadWriteAccess
 93		}
 94
 95		// If the repository is private, the user has no access.
 96		if r.IsPrivate() {
 97			return backend.NoAccess
 98		}
 99
100		// Otherwise, the user has read-only access.
101		return backend.ReadOnlyAccess
102	}
103
104	// If the repository doesn't exist, the user has read/write access.
105	if user != nil {
106		// If the repository doesn't exist, the user has read/write access.
107		if anon > backend.ReadWriteAccess {
108			return anon
109		}
110
111		return backend.ReadWriteAccess
112	}
113
114	// If the user doesn't exist, give them the anonymous access level.
115	return anon
116}
117
118// AccessLevelByPublicKey returns the access level of a user's public key for a repository.
119//
120// It implements backend.Backend.
121func (d *SqliteBackend) AccessLevelByPublicKey(repo string, pk ssh.PublicKey) backend.AccessLevel {
122	ak := backend.MarshalAuthorizedKey(pk)
123	for _, k := range d.AdditionalAdmins {
124		if k == ak {
125			return backend.AdminAccess
126		}
127	}
128
129	user, _ := d.UserByPublicKey(pk)
130	if user != nil {
131		return d.AccessLevel(repo, user.Username())
132	}
133
134	return d.AccessLevel(repo, "")
135}
136
137// AddPublicKey adds a public key to a user.
138//
139// It implements backend.Backend.
140func (d *SqliteBackend) AddPublicKey(username string, pk ssh.PublicKey) error {
141	username = strings.ToLower(username)
142	if err := utils.ValidateUsername(username); err != nil {
143		return err
144	}
145
146	return wrapDbErr(
147		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
148			var userID int
149			if err := tx.Get(&userID, "SELECT id FROM user WHERE username = ?", username); err != nil {
150				return err
151			}
152
153			_, err := tx.Exec(`INSERT INTO public_key (user_id, public_key, updated_at)
154			VALUES (?, ?, CURRENT_TIMESTAMP);`, userID, backend.MarshalAuthorizedKey(pk))
155			return err
156		}),
157	)
158}
159
160// CreateUser creates a new user.
161//
162// It implements backend.Backend.
163func (d *SqliteBackend) CreateUser(username string, opts backend.UserOptions) (backend.User, error) {
164	username = strings.ToLower(username)
165	if err := utils.ValidateUsername(username); err != nil {
166		return nil, err
167	}
168
169	var user *User
170	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
171		stmt, err := tx.Prepare("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);")
172		if err != nil {
173			return err
174		}
175
176		defer stmt.Close() // nolint: errcheck
177		r, err := stmt.Exec(username, opts.Admin)
178		if err != nil {
179			return err
180		}
181
182		if len(opts.PublicKeys) > 0 {
183			userID, err := r.LastInsertId()
184			if err != nil {
185				logger.Error("error getting last insert id")
186				return err
187			}
188
189			for _, pk := range opts.PublicKeys {
190				stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
191					VALUES (?, ?, CURRENT_TIMESTAMP);`)
192				if err != nil {
193					return err
194				}
195
196				defer stmt.Close() // nolint: errcheck
197				if _, err := stmt.Exec(userID, backend.MarshalAuthorizedKey(pk)); err != nil {
198					return err
199				}
200			}
201		}
202
203		user = &User{
204			db:       d.db,
205			username: username,
206		}
207		return nil
208	}); err != nil {
209		return nil, wrapDbErr(err)
210	}
211
212	return user, nil
213}
214
215// DeleteUser deletes a user.
216//
217// It implements backend.Backend.
218func (d *SqliteBackend) DeleteUser(username string) error {
219	username = strings.ToLower(username)
220	if err := utils.ValidateUsername(username); err != nil {
221		return err
222	}
223
224	return wrapDbErr(
225		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
226			_, err := tx.Exec("DELETE FROM user WHERE username = ?", username)
227			return err
228		}),
229	)
230}
231
232// RemovePublicKey removes a public key from a user.
233//
234// It implements backend.Backend.
235func (d *SqliteBackend) RemovePublicKey(username string, pk ssh.PublicKey) error {
236	return wrapDbErr(
237		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
238			_, err := tx.Exec(`DELETE FROM public_key
239			WHERE user_id = (SELECT id FROM user WHERE username = ?)
240			AND public_key = ?;`, username, backend.MarshalAuthorizedKey(pk))
241			return err
242		}),
243	)
244}
245
246// ListPublicKeys lists the public keys of a user.
247func (d *SqliteBackend) ListPublicKeys(username string) ([]ssh.PublicKey, error) {
248	username = strings.ToLower(username)
249	if err := utils.ValidateUsername(username); err != nil {
250		return nil, err
251	}
252
253	keys := make([]ssh.PublicKey, 0)
254	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
255		var keyStrings []string
256		if err := tx.Select(&keyStrings, `SELECT public_key
257			FROM public_key
258			INNER JOIN user ON user.id = public_key.user_id
259			WHERE user.username = ?;`, username); err != nil {
260			return err
261		}
262
263		for _, keyString := range keyStrings {
264			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
265			if err != nil {
266				return err
267			}
268			keys = append(keys, key)
269		}
270
271		return nil
272	}); err != nil {
273		return nil, wrapDbErr(err)
274	}
275
276	return keys, nil
277}
278
279// SetUsername sets the username of a user.
280//
281// It implements backend.Backend.
282func (d *SqliteBackend) SetUsername(username string, newUsername string) error {
283	username = strings.ToLower(username)
284	if err := utils.ValidateUsername(username); err != nil {
285		return err
286	}
287
288	return wrapDbErr(
289		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
290			_, err := tx.Exec("UPDATE user SET username = ? WHERE username = ?", newUsername, username)
291			return err
292		}),
293	)
294}
295
296// SetAdmin sets the admin flag of a user.
297//
298// It implements backend.Backend.
299func (d *SqliteBackend) SetAdmin(username string, admin bool) error {
300	username = strings.ToLower(username)
301	if err := utils.ValidateUsername(username); err != nil {
302		return err
303	}
304
305	return wrapDbErr(
306		wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
307			_, err := tx.Exec("UPDATE user SET admin = ? WHERE username = ?", admin, username)
308			return err
309		}),
310	)
311}
312
313// User finds a user by username.
314//
315// It implements backend.Backend.
316func (d *SqliteBackend) User(username string) (backend.User, error) {
317	username = strings.ToLower(username)
318	if err := utils.ValidateUsername(username); err != nil {
319		return nil, err
320	}
321
322	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
323		return tx.Get(&username, "SELECT username FROM user WHERE username = ?", username)
324	}); err != nil {
325		return nil, wrapDbErr(err)
326	}
327
328	return &User{
329		db:       d.db,
330		username: username,
331	}, nil
332}
333
334// UserByPublicKey finds a user by public key.
335//
336// It implements backend.Backend.
337func (d *SqliteBackend) UserByPublicKey(pk ssh.PublicKey) (backend.User, error) {
338	var username string
339	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
340		return tx.Get(&username, `SELECT user.username
341			FROM public_key
342			INNER JOIN user ON user.id = public_key.user_id
343			WHERE public_key.public_key = ?;`, backend.MarshalAuthorizedKey(pk))
344	}); err != nil {
345		return nil, wrapDbErr(err)
346	}
347
348	return &User{
349		db:       d.db,
350		username: username,
351	}, nil
352}
353
354// Users returns all users.
355//
356// It implements backend.Backend.
357func (d *SqliteBackend) Users() ([]string, error) {
358	var users []string
359	if err := wrapTx(d.db, context.Background(), func(tx *sqlx.Tx) error {
360		return tx.Select(&users, "SELECT username FROM user")
361	}); err != nil {
362		return nil, wrapDbErr(err)
363	}
364
365	return users, nil
366}