user.go

  1package database
  2
  3import (
  4	"context"
  5	"strings"
  6
  7	"github.com/charmbracelet/soft-serve/pkg/db"
  8	"github.com/charmbracelet/soft-serve/pkg/db/models"
  9	"github.com/charmbracelet/soft-serve/pkg/sshutils"
 10	"github.com/charmbracelet/soft-serve/pkg/store"
 11	"github.com/charmbracelet/soft-serve/pkg/utils"
 12	"golang.org/x/crypto/ssh"
 13)
 14
 15type userStore struct{}
 16
 17var _ store.UserStore = (*userStore)(nil)
 18
 19// AddPublicKeyByUsername implements store.UserStore.
 20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
 21	username = strings.ToLower(username)
 22	if err := utils.ValidateUsername(username); err != nil {
 23		return err //nolint:wrapcheck
 24	}
 25
 26	var userID int64
 27	if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {
 28		return err //nolint:wrapcheck
 29	}
 30
 31	query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
 32			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 33	ak := sshutils.MarshalAuthorizedKey(pk)
 34	_, err := tx.ExecContext(ctx, query, userID, ak)
 35
 36	return err //nolint:wrapcheck
 37}
 38
 39// CreateUser implements store.UserStore.
 40func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error {
 41	username = strings.ToLower(username)
 42	if err := utils.ValidateUsername(username); err != nil {
 43		return err //nolint:wrapcheck
 44	}
 45
 46	query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)
 47			VALUES (?, ?, CURRENT_TIMESTAMP) RETURNING id;`)
 48
 49	var userID int64
 50	if err := tx.GetContext(ctx, &userID, query, username, isAdmin); err != nil {
 51		return err //nolint:wrapcheck
 52	}
 53
 54	for _, pk := range pks {
 55		query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
 56			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 57		ak := sshutils.MarshalAuthorizedKey(pk)
 58		_, err := tx.ExecContext(ctx, query, userID, ak)
 59		if err != nil {
 60			return err //nolint:wrapcheck
 61		}
 62	}
 63
 64	return nil
 65}
 66
 67// DeleteUserByUsername implements store.UserStore.
 68func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
 69	username = strings.ToLower(username)
 70	if err := utils.ValidateUsername(username); err != nil {
 71		return err //nolint:wrapcheck
 72	}
 73
 74	query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)
 75	_, err := tx.ExecContext(ctx, query, username)
 76	return err //nolint:wrapcheck
 77}
 78
 79// GetUserByID implements store.UserStore.
 80func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
 81	var m models.User
 82	query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
 83	err := tx.GetContext(ctx, &m, query, id)
 84	return m, err //nolint:wrapcheck
 85}
 86
 87// FindUserByPublicKey implements store.UserStore.
 88func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
 89	var m models.User
 90	query := tx.Rebind(`SELECT users.*
 91			FROM users
 92			INNER JOIN public_keys ON users.id = public_keys.user_id
 93			WHERE public_keys.public_key = ?;`)
 94	err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
 95	return m, err //nolint:wrapcheck
 96}
 97
 98// FindUserByUsername implements store.UserStore.
 99func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
100	username = strings.ToLower(username)
101	if err := utils.ValidateUsername(username); err != nil {
102		return models.User{}, err //nolint:wrapcheck
103	}
104
105	var m models.User
106	query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)
107	err := tx.GetContext(ctx, &m, query, username)
108	return m, err //nolint:wrapcheck
109}
110
111// FindUserByAccessToken implements store.UserStore.
112func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {
113	var m models.User
114	query := tx.Rebind(`SELECT users.*
115			FROM users
116			INNER JOIN access_tokens ON users.id = access_tokens.user_id
117			WHERE access_tokens.token = ?;`)
118	err := tx.GetContext(ctx, &m, query, token)
119	return m, err //nolint:wrapcheck
120}
121
122// GetAllUsers implements store.UserStore.
123func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
124	var ms []models.User
125	query := tx.Rebind(`SELECT * FROM users;`)
126	err := tx.SelectContext(ctx, &ms, query)
127	return ms, err //nolint:wrapcheck
128}
129
130// ListPublicKeysByUserID implements store.UserStore..
131func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
132	var aks []string
133	query := tx.Rebind(`SELECT public_key FROM public_keys
134			WHERE user_id = ?
135			ORDER BY public_keys.id ASC;`)
136	err := tx.SelectContext(ctx, &aks, query, id)
137	if err != nil {
138		return nil, err //nolint:wrapcheck
139	}
140
141	pks := make([]ssh.PublicKey, len(aks))
142	for i, ak := range aks {
143		pk, _, err := sshutils.ParseAuthorizedKey(ak)
144		if err != nil {
145			return nil, err //nolint:wrapcheck
146		}
147		pks[i] = pk
148	}
149
150	return pks, nil
151}
152
153// ListPublicKeysByUsername implements store.UserStore.
154func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
155	username = strings.ToLower(username)
156	if err := utils.ValidateUsername(username); err != nil {
157		return nil, err //nolint:wrapcheck
158	}
159
160	var aks []string
161	query := tx.Rebind(`SELECT public_key FROM public_keys
162			INNER JOIN users ON users.id = public_keys.user_id
163			WHERE users.username = ?
164			ORDER BY public_keys.id ASC;`)
165	err := tx.SelectContext(ctx, &aks, query, username)
166	if err != nil {
167		return nil, err //nolint:wrapcheck
168	}
169
170	pks := make([]ssh.PublicKey, len(aks))
171	for i, ak := range aks {
172		pk, _, err := sshutils.ParseAuthorizedKey(ak)
173		if err != nil {
174			return nil, err //nolint:wrapcheck
175		}
176		pks[i] = pk
177	}
178
179	return pks, nil
180}
181
182// RemovePublicKeyByUsername implements store.UserStore.
183func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
184	username = strings.ToLower(username)
185	if err := utils.ValidateUsername(username); err != nil {
186		return err //nolint:wrapcheck
187	}
188
189	query := tx.Rebind(`DELETE FROM public_keys
190			WHERE user_id = (SELECT id FROM users WHERE username = ?)
191			AND public_key = ?;`)
192	_, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
193	return err //nolint:wrapcheck
194}
195
196// SetAdminByUsername implements store.UserStore.
197func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
198	username = strings.ToLower(username)
199	if err := utils.ValidateUsername(username); err != nil {
200		return err //nolint:wrapcheck
201	}
202
203	query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)
204	_, err := tx.ExecContext(ctx, query, isAdmin, username)
205	return err //nolint:wrapcheck
206}
207
208// SetUsernameByUsername implements store.UserStore.
209func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
210	username = strings.ToLower(username)
211	if err := utils.ValidateUsername(username); err != nil {
212		return err //nolint:wrapcheck
213	}
214
215	newUsername = strings.ToLower(newUsername)
216	if err := utils.ValidateUsername(newUsername); err != nil {
217		return err //nolint:wrapcheck
218	}
219
220	query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)
221	_, err := tx.ExecContext(ctx, query, newUsername, username)
222	return err //nolint:wrapcheck
223}
224
225// SetUserPassword implements store.UserStore.
226func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
227	query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
228	_, err := tx.ExecContext(ctx, query, password, userID)
229	return err //nolint:wrapcheck
230}
231
232// SetUserPasswordByUsername implements store.UserStore.
233func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
234	username = strings.ToLower(username)
235	if err := utils.ValidateUsername(username); err != nil {
236		return err //nolint:wrapcheck
237	}
238
239	query := tx.Rebind(`UPDATE users SET password = ? WHERE username = ?;`)
240	_, err := tx.ExecContext(ctx, query, password, username)
241	return err //nolint:wrapcheck
242}