1package database
2
3import (
4 "context"
5 "fmt"
6 "strings"
7
8 "github.com/charmbracelet/soft-serve/pkg/db"
9 "github.com/charmbracelet/soft-serve/pkg/db/models"
10 "github.com/charmbracelet/soft-serve/pkg/sshutils"
11 "github.com/charmbracelet/soft-serve/pkg/store"
12 "github.com/charmbracelet/soft-serve/pkg/utils"
13 "golang.org/x/crypto/ssh"
14)
15
16type userStore struct{ *handleStore }
17
18var _ store.UserStore = (*userStore)(nil)
19
20// AddPublicKeyByUsername implements store.UserStore.
21func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
22 username = strings.ToLower(username)
23 if err := utils.ValidateHandle(username); err != nil {
24 return err
25 }
26
27 var userID int64
28 if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT users.id FROM users
29 INNER JOIN handles ON handles.id = users.handle_id
30 WHERE handles.handle = ?;`), username); err != nil {
31 return err
32 }
33
34 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
35 VALUES (?, ?, CURRENT_TIMESTAMP);`)
36 ak := sshutils.MarshalAuthorizedKey(pk)
37 _, err := tx.ExecContext(ctx, query, userID, ak)
38
39 return err
40}
41
42// CreateUser implements store.UserStore.
43func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey, emails []string) error {
44 handleID, err := s.CreateHandle(ctx, tx, username)
45 if err != nil {
46 return err
47 }
48
49 query := tx.Rebind(`
50 INSERT INTO
51 users (handle_id, admin, updated_at)
52 VALUES
53 (?, ?, CURRENT_TIMESTAMP) RETURNING id;
54 `)
55
56 var userID int64
57 if err := tx.GetContext(ctx, &userID, query, handleID, isAdmin); err != nil {
58 return err
59 }
60
61 for _, pk := range pks {
62 query := tx.Rebind(`
63 INSERT INTO
64 public_keys (user_id, public_key, updated_at)
65 VALUES
66 (?, ?, CURRENT_TIMESTAMP);
67 `)
68 ak := sshutils.MarshalAuthorizedKey(pk)
69 _, err := tx.ExecContext(ctx, query, userID, ak)
70 if err != nil {
71 return err
72 }
73 }
74
75 for i, e := range emails {
76 if err := s.AddUserEmail(ctx, tx, userID, e, i == 0); err != nil {
77 return err
78 }
79 }
80
81 return nil
82}
83
84// DeleteUserByUsername implements store.UserStore.
85func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
86 username = strings.ToLower(username)
87 if err := utils.ValidateHandle(username); err != nil {
88 return err
89 }
90
91 query := tx.Rebind(`DELETE FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
92 _, err := tx.ExecContext(ctx, query, username)
93 return err
94}
95
96// GetUserByID implements store.UserStore.
97func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
98 var m models.User
99 query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
100 err := tx.GetContext(ctx, &m, query, id)
101 return m, err
102}
103
104// FindUserByPublicKey implements store.UserStore.
105func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
106 var m models.User
107 query := tx.Rebind(`SELECT users.*
108 FROM users
109 INNER JOIN public_keys ON users.id = public_keys.user_id
110 WHERE public_keys.public_key = ?;`)
111 err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
112 return m, err
113}
114
115// FindUserByUsername implements store.UserStore.
116func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
117 username = strings.ToLower(username)
118 if err := utils.ValidateHandle(username); err != nil {
119 return models.User{}, err
120 }
121
122 var m models.User
123 query := tx.Rebind(`SELECT * FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
124 err := tx.GetContext(ctx, &m, query, username)
125 return m, err
126}
127
128// FindUserByAccessToken implements store.UserStore.
129func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {
130 var m models.User
131 query := tx.Rebind(`SELECT users.*
132 FROM users
133 INNER JOIN access_tokens ON users.id = access_tokens.user_id
134 WHERE access_tokens.token = ?;`)
135 err := tx.GetContext(ctx, &m, query, token)
136 return m, err
137}
138
139// GetAllUsers implements store.UserStore.
140func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
141 var ms []models.User
142 query := tx.Rebind(`SELECT * FROM users;`)
143 err := tx.SelectContext(ctx, &ms, query)
144 return ms, err
145}
146
147// ListPublicKeysByUserID implements store.UserStore..
148func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
149 var aks []string
150 query := tx.Rebind(`SELECT public_key FROM public_keys
151 WHERE user_id = ?
152 ORDER BY public_keys.id ASC;`)
153 err := tx.SelectContext(ctx, &aks, query, id)
154 if err != nil {
155 return nil, err
156 }
157
158 pks := make([]ssh.PublicKey, len(aks))
159 for i, ak := range aks {
160 pk, _, err := sshutils.ParseAuthorizedKey(ak)
161 if err != nil {
162 return nil, err
163 }
164 pks[i] = pk
165 }
166
167 return pks, nil
168}
169
170// ListPublicKeysByUsername implements store.UserStore.
171func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
172 username = strings.ToLower(username)
173 if err := utils.ValidateHandle(username); err != nil {
174 return nil, err
175 }
176
177 var aks []string
178 query := tx.Rebind(`SELECT public_key FROM public_keys
179 INNER JOIN users ON users.id = public_keys.user_id
180 WHERE users.handle_id = (SELECT id FROM handles WHERE handle = ?)
181 ORDER BY public_keys.id ASC;`)
182 err := tx.SelectContext(ctx, &aks, query, username)
183 if err != nil {
184 return nil, err
185 }
186
187 pks := make([]ssh.PublicKey, len(aks))
188 for i, ak := range aks {
189 pk, _, err := sshutils.ParseAuthorizedKey(ak)
190 if err != nil {
191 return nil, err
192 }
193 pks[i] = pk
194 }
195
196 return pks, nil
197}
198
199// RemovePublicKeyByUsername implements store.UserStore.
200func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
201 username = strings.ToLower(username)
202 if err := utils.ValidateHandle(username); err != nil {
203 return err
204 }
205
206 query := tx.Rebind(`DELETE FROM public_keys
207 WHERE user_id = (SELECT id FROM users WHERE handle_id = (
208 SELECT id FROM handles WHERE handle = ?
209 ))
210 AND public_key = ?;`)
211 _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
212 return err
213}
214
215// SetAdminByUsername implements store.UserStore.
216func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
217 username = strings.ToLower(username)
218 if err := utils.ValidateHandle(username); err != nil {
219 return err
220 }
221
222 query := tx.Rebind(`UPDATE users SET admin = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?)`)
223 _, err := tx.ExecContext(ctx, query, isAdmin, username)
224 return err
225}
226
227// SetUsernameByUsername implements store.UserStore.
228func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
229 username = strings.ToLower(username)
230 if err := utils.ValidateHandle(username); err != nil {
231 return err
232 }
233
234 newUsername = strings.ToLower(newUsername)
235 if err := utils.ValidateHandle(newUsername); err != nil {
236 return err
237 }
238
239 query := tx.Rebind(`UPDATE handles SET handle = ? WHERE handle = ?;`)
240 _, err := tx.ExecContext(ctx, query, newUsername, username)
241 return err
242}
243
244// SetUserPassword implements store.UserStore.
245func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
246 query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
247 _, err := tx.ExecContext(ctx, query, password, userID)
248 return err
249}
250
251// SetUserPasswordByUsername implements store.UserStore.
252func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
253 username = strings.ToLower(username)
254 if err := utils.ValidateHandle(username); err != nil {
255 return err
256 }
257
258 query := tx.Rebind(`UPDATE users SET password = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
259 _, err := tx.ExecContext(ctx, query, password, username)
260 return err
261}
262
263// AddUserEmail implements store.UserStore.
264func (*userStore) AddUserEmail(ctx context.Context, tx db.Handler, userID int64, email string, isPrimary bool) error {
265 if err := utils.ValidateEmail(email); err != nil {
266 return err
267 }
268 query := tx.Rebind(`INSERT INTO user_emails (user_id, email, is_primary, updated_at)
269 VALUES (?, ?, ?, CURRENT_TIMESTAMP);`)
270 _, err := tx.ExecContext(ctx, query, userID, email, isPrimary)
271 return err
272}
273
274// ListUserEmails implements store.UserStore.
275func (*userStore) ListUserEmails(ctx context.Context, tx db.Handler, userID int64) ([]models.UserEmail, error) {
276 var ms []models.UserEmail
277 query := tx.Rebind(`SELECT * FROM user_emails WHERE user_id = ?;`)
278 err := tx.SelectContext(ctx, &ms, query, userID)
279 return ms, err
280}
281
282// RemoveUserEmail implements store.UserStore.
283func (*userStore) RemoveUserEmail(ctx context.Context, tx db.Handler, userID int64, email string) error {
284 var e models.UserEmail
285 query := tx.Rebind(`DELETE FROM user_emails WHERE user_id = ? AND email = ? RETURNING *;`)
286 if err := tx.GetContext(ctx, &e, query, userID, email); err != nil {
287 return err
288 }
289
290 if e.IsPrimary {
291 return fmt.Errorf("cannot remove primary email")
292 } else if e.ID == 0 {
293 return db.ErrRecordNotFound
294 }
295
296 return nil
297}
298
299// SetUserPrimaryEmail implements store.UserStore.
300func (*userStore) SetUserPrimaryEmail(ctx context.Context, tx db.Handler, userID int64, email string) error {
301 query := tx.Rebind(`UPDATE user_emails SET is_primary = FALSE WHERE user_id = ?;`)
302 _, err := tx.ExecContext(ctx, query, userID)
303 if err != nil {
304 return err
305 }
306
307 var emailID int64
308 query = tx.Rebind(`UPDATE user_emails SET is_primary = TRUE WHERE user_id = ? AND email = ? RETURNING id;`)
309 if err := tx.GetContext(ctx, &emailID, query, userID, email); err != nil {
310 return err
311 }
312
313 if emailID == 0 {
314 return db.ErrRecordNotFound
315 }
316
317 return nil
318}