user.go

  1package database
  2
  3import (
  4	"context"
  5	"fmt"
  6	"strings"
  7
  8	"github.com/charmbracelet/soft-serve/pkg/db"
  9	"github.com/charmbracelet/soft-serve/pkg/db/models"
 10	"github.com/charmbracelet/soft-serve/pkg/sshutils"
 11	"github.com/charmbracelet/soft-serve/pkg/store"
 12	"github.com/charmbracelet/soft-serve/pkg/utils"
 13	"golang.org/x/crypto/ssh"
 14)
 15
 16type userStore struct{ *handleStore }
 17
 18var _ store.UserStore = (*userStore)(nil)
 19
 20// AddPublicKeyByUsername implements store.UserStore.
 21func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
 22	username = strings.ToLower(username)
 23	if err := utils.ValidateHandle(username); err != nil {
 24		return err
 25	}
 26
 27	var userID int64
 28	if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT users.id FROM users
 29			INNER JOIN handles ON handles.id = users.handle_id
 30			WHERE handles.handle = ?;`), username); err != nil {
 31		return err
 32	}
 33
 34	query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
 35			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 36	ak := sshutils.MarshalAuthorizedKey(pk)
 37	_, err := tx.ExecContext(ctx, query, userID, ak)
 38
 39	return err
 40}
 41
 42// CreateUser implements store.UserStore.
 43func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey, emails []string) error {
 44	handleID, err := s.CreateHandle(ctx, tx, username)
 45	if err != nil {
 46		return err
 47	}
 48
 49	query := tx.Rebind(`
 50		INSERT INTO
 51		  users (handle_id, admin, updated_at)
 52		VALUES
 53		  (?, ?, CURRENT_TIMESTAMP) RETURNING id;
 54	`)
 55
 56	var userID int64
 57	if err := tx.GetContext(ctx, &userID, query, handleID, isAdmin); err != nil {
 58		return err
 59	}
 60
 61	for _, pk := range pks {
 62		query := tx.Rebind(`
 63			INSERT INTO
 64			  public_keys (user_id, public_key, updated_at)
 65			VALUES
 66			  (?, ?, CURRENT_TIMESTAMP);
 67		`)
 68		ak := sshutils.MarshalAuthorizedKey(pk)
 69		_, err := tx.ExecContext(ctx, query, userID, ak)
 70		if err != nil {
 71			return err
 72		}
 73	}
 74
 75	for i, e := range emails {
 76		if err := s.AddUserEmail(ctx, tx, userID, e, i == 0); err != nil {
 77			return err
 78		}
 79	}
 80
 81	return nil
 82}
 83
 84// DeleteUserByUsername implements store.UserStore.
 85func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
 86	username = strings.ToLower(username)
 87	if err := utils.ValidateHandle(username); err != nil {
 88		return err
 89	}
 90
 91	query := tx.Rebind(`DELETE FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
 92	_, err := tx.ExecContext(ctx, query, username)
 93	return err
 94}
 95
 96// GetUserByID implements store.UserStore.
 97func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
 98	var m models.User
 99	query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
100	err := tx.GetContext(ctx, &m, query, id)
101	return m, err
102}
103
104// FindUserByPublicKey implements store.UserStore.
105func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
106	var m models.User
107	query := tx.Rebind(`SELECT users.*
108			FROM users
109			INNER JOIN public_keys ON users.id = public_keys.user_id
110			WHERE public_keys.public_key = ?;`)
111	err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
112	return m, err
113}
114
115// FindUserByUsername implements store.UserStore.
116func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
117	username = strings.ToLower(username)
118	if err := utils.ValidateHandle(username); err != nil {
119		return models.User{}, err
120	}
121
122	var m models.User
123	query := tx.Rebind(`SELECT * FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
124	err := tx.GetContext(ctx, &m, query, username)
125	return m, err
126}
127
128// FindUserByAccessToken implements store.UserStore.
129func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {
130	var m models.User
131	query := tx.Rebind(`SELECT users.*
132			FROM users
133			INNER JOIN access_tokens ON users.id = access_tokens.user_id
134			WHERE access_tokens.token = ?;`)
135	err := tx.GetContext(ctx, &m, query, token)
136	return m, err
137}
138
139// GetAllUsers implements store.UserStore.
140func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
141	var ms []models.User
142	query := tx.Rebind(`SELECT * FROM users;`)
143	err := tx.SelectContext(ctx, &ms, query)
144	return ms, err
145}
146
147// ListPublicKeysByUserID implements store.UserStore..
148func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
149	var aks []string
150	query := tx.Rebind(`SELECT public_key FROM public_keys
151			WHERE user_id = ?
152			ORDER BY public_keys.id ASC;`)
153	err := tx.SelectContext(ctx, &aks, query, id)
154	if err != nil {
155		return nil, err
156	}
157
158	pks := make([]ssh.PublicKey, len(aks))
159	for i, ak := range aks {
160		pk, _, err := sshutils.ParseAuthorizedKey(ak)
161		if err != nil {
162			return nil, err
163		}
164		pks[i] = pk
165	}
166
167	return pks, nil
168}
169
170// ListPublicKeysByUsername implements store.UserStore.
171func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
172	username = strings.ToLower(username)
173	if err := utils.ValidateHandle(username); err != nil {
174		return nil, err
175	}
176
177	var aks []string
178	query := tx.Rebind(`SELECT public_key FROM public_keys
179			INNER JOIN users ON users.id = public_keys.user_id
180			WHERE users.handle_id = (SELECT id FROM handles WHERE handle = ?)
181			ORDER BY public_keys.id ASC;`)
182	err := tx.SelectContext(ctx, &aks, query, username)
183	if err != nil {
184		return nil, err
185	}
186
187	pks := make([]ssh.PublicKey, len(aks))
188	for i, ak := range aks {
189		pk, _, err := sshutils.ParseAuthorizedKey(ak)
190		if err != nil {
191			return nil, err
192		}
193		pks[i] = pk
194	}
195
196	return pks, nil
197}
198
199// RemovePublicKeyByUsername implements store.UserStore.
200func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
201	username = strings.ToLower(username)
202	if err := utils.ValidateHandle(username); err != nil {
203		return err
204	}
205
206	query := tx.Rebind(`DELETE FROM public_keys
207			WHERE user_id = (SELECT id FROM users WHERE handle_id = (
208				SELECT id FROM handles WHERE handle = ?
209			))
210			AND public_key = ?;`)
211	_, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
212	return err
213}
214
215// SetAdminByUsername implements store.UserStore.
216func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
217	username = strings.ToLower(username)
218	if err := utils.ValidateHandle(username); err != nil {
219		return err
220	}
221
222	query := tx.Rebind(`UPDATE users SET admin = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?)`)
223	_, err := tx.ExecContext(ctx, query, isAdmin, username)
224	return err
225}
226
227// SetUsernameByUsername implements store.UserStore.
228func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
229	username = strings.ToLower(username)
230	if err := utils.ValidateHandle(username); err != nil {
231		return err
232	}
233
234	newUsername = strings.ToLower(newUsername)
235	if err := utils.ValidateHandle(newUsername); err != nil {
236		return err
237	}
238
239	query := tx.Rebind(`UPDATE handles SET handle = ? WHERE handle = ?;`)
240	_, err := tx.ExecContext(ctx, query, newUsername, username)
241	return err
242}
243
244// SetUserPassword implements store.UserStore.
245func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
246	query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
247	_, err := tx.ExecContext(ctx, query, password, userID)
248	return err
249}
250
251// SetUserPasswordByUsername implements store.UserStore.
252func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
253	username = strings.ToLower(username)
254	if err := utils.ValidateHandle(username); err != nil {
255		return err
256	}
257
258	query := tx.Rebind(`UPDATE users SET password = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
259	_, err := tx.ExecContext(ctx, query, password, username)
260	return err
261}
262
263// AddUserEmail implements store.UserStore.
264func (*userStore) AddUserEmail(ctx context.Context, tx db.Handler, userID int64, email string, isPrimary bool) error {
265	if err := utils.ValidateEmail(email); err != nil {
266		return err
267	}
268	query := tx.Rebind(`INSERT INTO user_emails (user_id, email, is_primary, updated_at)
269			VALUES (?, ?, ?, CURRENT_TIMESTAMP);`)
270	_, err := tx.ExecContext(ctx, query, userID, email, isPrimary)
271	return err
272}
273
274// ListUserEmails implements store.UserStore.
275func (*userStore) ListUserEmails(ctx context.Context, tx db.Handler, userID int64) ([]models.UserEmail, error) {
276	var ms []models.UserEmail
277	query := tx.Rebind(`SELECT * FROM user_emails WHERE user_id = ?;`)
278	err := tx.SelectContext(ctx, &ms, query, userID)
279	return ms, err
280}
281
282// RemoveUserEmail implements store.UserStore.
283func (*userStore) RemoveUserEmail(ctx context.Context, tx db.Handler, userID int64, email string) error {
284	var e models.UserEmail
285	query := tx.Rebind(`DELETE FROM user_emails WHERE user_id = ? AND email = ? RETURNING *;`)
286	if err := tx.GetContext(ctx, &e, query, userID, email); err != nil {
287		return err
288	}
289
290	if e.IsPrimary {
291		return fmt.Errorf("cannot remove primary email")
292	} else if e.ID == 0 {
293		return db.ErrRecordNotFound
294	}
295
296	return nil
297}
298
299// SetUserPrimaryEmail implements store.UserStore.
300func (*userStore) SetUserPrimaryEmail(ctx context.Context, tx db.Handler, userID int64, email string) error {
301	query := tx.Rebind(`UPDATE user_emails SET is_primary = FALSE WHERE user_id = ?;`)
302	_, err := tx.ExecContext(ctx, query, userID)
303	if err != nil {
304		return err
305	}
306
307	var emailID int64
308	query = tx.Rebind(`UPDATE user_emails SET is_primary = TRUE WHERE user_id = ? AND email = ? RETURNING id;`)
309	if err := tx.GetContext(ctx, &emailID, query, userID, email); err != nil {
310		return err
311	}
312
313	if emailID == 0 {
314		return db.ErrRecordNotFound
315	}
316
317	return nil
318}