user.go

  1package database
  2
  3import (
  4	"context"
  5	"strings"
  6
  7	"github.com/charmbracelet/soft-serve/pkg/db"
  8	"github.com/charmbracelet/soft-serve/pkg/db/models"
  9	"github.com/charmbracelet/soft-serve/pkg/sshutils"
 10	"github.com/charmbracelet/soft-serve/pkg/store"
 11	"github.com/charmbracelet/soft-serve/pkg/utils"
 12	"golang.org/x/crypto/ssh"
 13)
 14
 15type userStore struct{ *handleStore }
 16
 17var _ store.UserStore = (*userStore)(nil)
 18
 19// AddPublicKeyByUsername implements store.UserStore.
 20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
 21	username = strings.ToLower(username)
 22	if err := utils.ValidateHandle(username); err != nil {
 23		return err
 24	}
 25
 26	var userID int64
 27	if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT users.id FROM users
 28			INNER JOIN handles ON handles.id = users.handle_id
 29			WHERE handles.handle = ?;`), username); err != nil {
 30		return err
 31	}
 32
 33	query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
 34			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 35	ak := sshutils.MarshalAuthorizedKey(pk)
 36	_, err := tx.ExecContext(ctx, query, userID, ak)
 37
 38	return err
 39}
 40
 41// CreateUser implements store.UserStore.
 42func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error {
 43	handleID, err := s.CreateHandle(ctx, tx, username)
 44	if err != nil {
 45		return err
 46	}
 47
 48	query := tx.Rebind(`
 49		INSERT INTO
 50		  users (handle_id, admin, updated_at)
 51		VALUES
 52		  (?, ?, CURRENT_TIMESTAMP) RETURNING id;
 53	`)
 54
 55	var userID int64
 56	if err := tx.GetContext(ctx, &userID, query, handleID, isAdmin); err != nil {
 57		return err
 58	}
 59
 60	for _, pk := range pks {
 61		query := tx.Rebind(`
 62			INSERT INTO
 63			  public_keys (user_id, public_key, updated_at)
 64			VALUES
 65			  (?, ?, CURRENT_TIMESTAMP);
 66		`)
 67		ak := sshutils.MarshalAuthorizedKey(pk)
 68		_, err := tx.ExecContext(ctx, query, userID, ak)
 69		if err != nil {
 70			return err
 71		}
 72	}
 73
 74	return nil
 75}
 76
 77// DeleteUserByUsername implements store.UserStore.
 78func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
 79	username = strings.ToLower(username)
 80	if err := utils.ValidateHandle(username); err != nil {
 81		return err
 82	}
 83
 84	query := tx.Rebind(`DELETE FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
 85	_, err := tx.ExecContext(ctx, query, username)
 86	return err
 87}
 88
 89// GetUserByID implements store.UserStore.
 90func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
 91	var m models.User
 92	query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
 93	err := tx.GetContext(ctx, &m, query, id)
 94	return m, err
 95}
 96
 97// FindUserByPublicKey implements store.UserStore.
 98func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
 99	var m models.User
100	query := tx.Rebind(`SELECT users.*
101			FROM users
102			INNER JOIN public_keys ON users.id = public_keys.user_id
103			WHERE public_keys.public_key = ?;`)
104	err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
105	return m, err
106}
107
108// FindUserByUsername implements store.UserStore.
109func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
110	username = strings.ToLower(username)
111	if err := utils.ValidateHandle(username); err != nil {
112		return models.User{}, err
113	}
114
115	var m models.User
116	query := tx.Rebind(`SELECT * FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
117	err := tx.GetContext(ctx, &m, query, username)
118	return m, err
119}
120
121// FindUserByAccessToken implements store.UserStore.
122func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {
123	var m models.User
124	query := tx.Rebind(`SELECT users.*
125			FROM users
126			INNER JOIN access_tokens ON users.id = access_tokens.user_id
127			WHERE access_tokens.token = ?;`)
128	err := tx.GetContext(ctx, &m, query, token)
129	return m, err
130}
131
132// GetAllUsers implements store.UserStore.
133func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
134	var ms []models.User
135	query := tx.Rebind(`SELECT * FROM users;`)
136	err := tx.SelectContext(ctx, &ms, query)
137	return ms, err
138}
139
140// ListPublicKeysByUserID implements store.UserStore..
141func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
142	var aks []string
143	query := tx.Rebind(`SELECT public_key FROM public_keys
144			WHERE user_id = ?
145			ORDER BY public_keys.id ASC;`)
146	err := tx.SelectContext(ctx, &aks, query, id)
147	if err != nil {
148		return nil, err
149	}
150
151	pks := make([]ssh.PublicKey, len(aks))
152	for i, ak := range aks {
153		pk, _, err := sshutils.ParseAuthorizedKey(ak)
154		if err != nil {
155			return nil, err
156		}
157		pks[i] = pk
158	}
159
160	return pks, nil
161}
162
163// ListPublicKeysByUsername implements store.UserStore.
164func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
165	username = strings.ToLower(username)
166	if err := utils.ValidateHandle(username); err != nil {
167		return nil, err
168	}
169
170	var aks []string
171	query := tx.Rebind(`SELECT public_key FROM public_keys
172			INNER JOIN users ON users.id = public_keys.user_id
173			WHERE users.handle_id = (SELECT id FROM handles WHERE handle = ?)
174			ORDER BY public_keys.id ASC;`)
175	err := tx.SelectContext(ctx, &aks, query, username)
176	if err != nil {
177		return nil, err
178	}
179
180	pks := make([]ssh.PublicKey, len(aks))
181	for i, ak := range aks {
182		pk, _, err := sshutils.ParseAuthorizedKey(ak)
183		if err != nil {
184			return nil, err
185		}
186		pks[i] = pk
187	}
188
189	return pks, nil
190}
191
192// RemovePublicKeyByUsername implements store.UserStore.
193func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
194	username = strings.ToLower(username)
195	if err := utils.ValidateHandle(username); err != nil {
196		return err
197	}
198
199	query := tx.Rebind(`DELETE FROM public_keys
200			WHERE user_id = (SELECT id FROM users WHERE handle_id = (
201				SELECT id FROM handles WHERE handle = ?
202			))
203			AND public_key = ?;`)
204	_, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
205	return err
206}
207
208// SetAdminByUsername implements store.UserStore.
209func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
210	username = strings.ToLower(username)
211	if err := utils.ValidateHandle(username); err != nil {
212		return err
213	}
214
215	query := tx.Rebind(`UPDATE users SET admin = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?)`)
216	_, err := tx.ExecContext(ctx, query, isAdmin, username)
217	return err
218}
219
220// SetUsernameByUsername implements store.UserStore.
221func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
222	username = strings.ToLower(username)
223	if err := utils.ValidateHandle(username); err != nil {
224		return err
225	}
226
227	newUsername = strings.ToLower(newUsername)
228	if err := utils.ValidateHandle(newUsername); err != nil {
229		return err
230	}
231
232	query := tx.Rebind(`UPDATE handles SET handle = ? WHERE handle = ?;`)
233	_, err := tx.ExecContext(ctx, query, newUsername, username)
234	return err
235}
236
237// SetUserPassword implements store.UserStore.
238func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
239	query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
240	_, err := tx.ExecContext(ctx, query, password, userID)
241	return err
242}
243
244// SetUserPasswordByUsername implements store.UserStore.
245func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
246	username = strings.ToLower(username)
247	if err := utils.ValidateHandle(username); err != nil {
248		return err
249	}
250
251	query := tx.Rebind(`UPDATE users SET password = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
252	_, err := tx.ExecContext(ctx, query, password, username)
253	return err
254}
255
256// AddUserEmail implements store.UserStore.
257func (*userStore) AddUserEmail(ctx context.Context, tx db.Handler, userID int64, email string, isPrimary bool) error {
258	query := tx.Rebind(`INSERT INTO user_emails (user_id, email, is_primary, updated_at)
259			VALUES (?, ?, ?, CURRENT_TIMESTAMP);`)
260	_, err := tx.ExecContext(ctx, query, userID, email, isPrimary)
261	return err
262}
263
264// ListUserEmails implements store.UserStore.
265func (*userStore) ListUserEmails(ctx context.Context, tx db.Handler, userID int64) ([]models.UserEmail, error) {
266	var ms []models.UserEmail
267	query := tx.Rebind(`SELECT * FROM user_emails WHERE user_id = ?;`)
268	err := tx.SelectContext(ctx, &ms, query, userID)
269	return ms, err
270}
271
272// UpdateUserEmail implements store.UserStore.
273func (*userStore) UpdateUserEmail(ctx context.Context, tx db.Handler, userID int64, oldEmail string, newEmail string, isPrimary bool) error {
274	query := tx.Rebind(`UPDATE user_emails SET email = ?, is_primary = ?, updated_at = CURRENT_TIMESTAMP WHERE user_id = ? AND email = ?;`)
275	_, err := tx.ExecContext(ctx, query, newEmail, isPrimary, userID, oldEmail)
276	return err
277}
278
279// DeleteUserEmail implements store.UserStore.
280func (*userStore) DeleteUserEmail(ctx context.Context, tx db.Handler, userID int64, email string) error {
281	query := tx.Rebind(`DELETE FROM user_emails WHERE user_id = ? AND email = ?;`)
282	_, err := tx.ExecContext(ctx, query, userID, email)
283	return err
284}