1package database
2
3import (
4 "context"
5 "strings"
6
7 "github.com/charmbracelet/soft-serve/pkg/db"
8 "github.com/charmbracelet/soft-serve/pkg/db/models"
9 "github.com/charmbracelet/soft-serve/pkg/sshutils"
10 "github.com/charmbracelet/soft-serve/pkg/store"
11 "github.com/charmbracelet/soft-serve/pkg/utils"
12 "golang.org/x/crypto/ssh"
13)
14
15type userStore struct{ *handleStore }
16
17var _ store.UserStore = (*userStore)(nil)
18
19// AddPublicKeyByUsername implements store.UserStore.
20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
21 username = strings.ToLower(username)
22 if err := utils.ValidateHandle(username); err != nil {
23 return err
24 }
25
26 var userID int64
27 if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT users.id FROM users
28 INNER JOIN handles ON handles.id = users.handle_id
29 WHERE handles.handle = ?;`), username); err != nil {
30 return err
31 }
32
33 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
34 VALUES (?, ?, CURRENT_TIMESTAMP);`)
35 ak := sshutils.MarshalAuthorizedKey(pk)
36 _, err := tx.ExecContext(ctx, query, userID, ak)
37
38 return err
39}
40
41// CreateUser implements store.UserStore.
42func (s *userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error {
43 handleID, err := s.CreateHandle(ctx, tx, username)
44 if err != nil {
45 return err
46 }
47
48 query := tx.Rebind(`
49 INSERT INTO
50 users (handle_id, admin, updated_at)
51 VALUES
52 (?, ?, CURRENT_TIMESTAMP) RETURNING id;
53 `)
54
55 var userID int64
56 if err := tx.GetContext(ctx, &userID, query, handleID, isAdmin); err != nil {
57 return err
58 }
59
60 for _, pk := range pks {
61 query := tx.Rebind(`
62 INSERT INTO
63 public_keys (user_id, public_key, updated_at)
64 VALUES
65 (?, ?, CURRENT_TIMESTAMP);
66 `)
67 ak := sshutils.MarshalAuthorizedKey(pk)
68 _, err := tx.ExecContext(ctx, query, userID, ak)
69 if err != nil {
70 return err
71 }
72 }
73
74 return nil
75}
76
77// DeleteUserByUsername implements store.UserStore.
78func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
79 username = strings.ToLower(username)
80 if err := utils.ValidateHandle(username); err != nil {
81 return err
82 }
83
84 query := tx.Rebind(`DELETE FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
85 _, err := tx.ExecContext(ctx, query, username)
86 return err
87}
88
89// GetUserByID implements store.UserStore.
90func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
91 var m models.User
92 query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
93 err := tx.GetContext(ctx, &m, query, id)
94 return m, err
95}
96
97// FindUserByPublicKey implements store.UserStore.
98func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
99 var m models.User
100 query := tx.Rebind(`SELECT users.*
101 FROM users
102 INNER JOIN public_keys ON users.id = public_keys.user_id
103 WHERE public_keys.public_key = ?;`)
104 err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
105 return m, err
106}
107
108// FindUserByUsername implements store.UserStore.
109func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
110 username = strings.ToLower(username)
111 if err := utils.ValidateHandle(username); err != nil {
112 return models.User{}, err
113 }
114
115 var m models.User
116 query := tx.Rebind(`SELECT * FROM users WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
117 err := tx.GetContext(ctx, &m, query, username)
118 return m, err
119}
120
121// FindUserByAccessToken implements store.UserStore.
122func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {
123 var m models.User
124 query := tx.Rebind(`SELECT users.*
125 FROM users
126 INNER JOIN access_tokens ON users.id = access_tokens.user_id
127 WHERE access_tokens.token = ?;`)
128 err := tx.GetContext(ctx, &m, query, token)
129 return m, err
130}
131
132// GetAllUsers implements store.UserStore.
133func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
134 var ms []models.User
135 query := tx.Rebind(`SELECT * FROM users;`)
136 err := tx.SelectContext(ctx, &ms, query)
137 return ms, err
138}
139
140// ListPublicKeysByUserID implements store.UserStore..
141func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
142 var aks []string
143 query := tx.Rebind(`SELECT public_key FROM public_keys
144 WHERE user_id = ?
145 ORDER BY public_keys.id ASC;`)
146 err := tx.SelectContext(ctx, &aks, query, id)
147 if err != nil {
148 return nil, err
149 }
150
151 pks := make([]ssh.PublicKey, len(aks))
152 for i, ak := range aks {
153 pk, _, err := sshutils.ParseAuthorizedKey(ak)
154 if err != nil {
155 return nil, err
156 }
157 pks[i] = pk
158 }
159
160 return pks, nil
161}
162
163// ListPublicKeysByUsername implements store.UserStore.
164func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
165 username = strings.ToLower(username)
166 if err := utils.ValidateHandle(username); err != nil {
167 return nil, err
168 }
169
170 var aks []string
171 query := tx.Rebind(`SELECT public_key FROM public_keys
172 INNER JOIN users ON users.id = public_keys.user_id
173 WHERE users.handle_id = (SELECT id FROM handles WHERE handle = ?)
174 ORDER BY public_keys.id ASC;`)
175 err := tx.SelectContext(ctx, &aks, query, username)
176 if err != nil {
177 return nil, err
178 }
179
180 pks := make([]ssh.PublicKey, len(aks))
181 for i, ak := range aks {
182 pk, _, err := sshutils.ParseAuthorizedKey(ak)
183 if err != nil {
184 return nil, err
185 }
186 pks[i] = pk
187 }
188
189 return pks, nil
190}
191
192// RemovePublicKeyByUsername implements store.UserStore.
193func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
194 username = strings.ToLower(username)
195 if err := utils.ValidateHandle(username); err != nil {
196 return err
197 }
198
199 query := tx.Rebind(`DELETE FROM public_keys
200 WHERE user_id = (SELECT id FROM users WHERE handle_id = (
201 SELECT id FROM handles WHERE handle = ?
202 ))
203 AND public_key = ?;`)
204 _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
205 return err
206}
207
208// SetAdminByUsername implements store.UserStore.
209func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
210 username = strings.ToLower(username)
211 if err := utils.ValidateHandle(username); err != nil {
212 return err
213 }
214
215 query := tx.Rebind(`UPDATE users SET admin = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?)`)
216 _, err := tx.ExecContext(ctx, query, isAdmin, username)
217 return err
218}
219
220// SetUsernameByUsername implements store.UserStore.
221func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
222 username = strings.ToLower(username)
223 if err := utils.ValidateHandle(username); err != nil {
224 return err
225 }
226
227 newUsername = strings.ToLower(newUsername)
228 if err := utils.ValidateHandle(newUsername); err != nil {
229 return err
230 }
231
232 query := tx.Rebind(`UPDATE handles SET handle = ? WHERE handle = ?;`)
233 _, err := tx.ExecContext(ctx, query, newUsername, username)
234 return err
235}
236
237// SetUserPassword implements store.UserStore.
238func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
239 query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
240 _, err := tx.ExecContext(ctx, query, password, userID)
241 return err
242}
243
244// SetUserPasswordByUsername implements store.UserStore.
245func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
246 username = strings.ToLower(username)
247 if err := utils.ValidateHandle(username); err != nil {
248 return err
249 }
250
251 query := tx.Rebind(`UPDATE users SET password = ? WHERE handle_id = (SELECT id FROM handles WHERE handle = ?);`)
252 _, err := tx.ExecContext(ctx, query, password, username)
253 return err
254}
255
256// AddUserEmail implements store.UserStore.
257func (*userStore) AddUserEmail(ctx context.Context, tx db.Handler, userID int64, email string, isPrimary bool) error {
258 query := tx.Rebind(`INSERT INTO user_emails (user_id, email, is_primary, updated_at)
259 VALUES (?, ?, ?, CURRENT_TIMESTAMP);`)
260 _, err := tx.ExecContext(ctx, query, userID, email, isPrimary)
261 return err
262}
263
264// ListUserEmails implements store.UserStore.
265func (*userStore) ListUserEmails(ctx context.Context, tx db.Handler, userID int64) ([]models.UserEmail, error) {
266 var ms []models.UserEmail
267 query := tx.Rebind(`SELECT * FROM user_emails WHERE user_id = ?;`)
268 err := tx.SelectContext(ctx, &ms, query, userID)
269 return ms, err
270}
271
272// UpdateUserEmail implements store.UserStore.
273func (*userStore) UpdateUserEmail(ctx context.Context, tx db.Handler, userID int64, oldEmail string, newEmail string, isPrimary bool) error {
274 query := tx.Rebind(`UPDATE user_emails SET email = ?, is_primary = ?, updated_at = CURRENT_TIMESTAMP WHERE user_id = ? AND email = ?;`)
275 _, err := tx.ExecContext(ctx, query, newEmail, isPrimary, userID, oldEmail)
276 return err
277}
278
279// DeleteUserEmail implements store.UserStore.
280func (*userStore) DeleteUserEmail(ctx context.Context, tx db.Handler, userID int64, email string) error {
281 query := tx.Rebind(`DELETE FROM user_emails WHERE user_id = ? AND email = ?;`)
282 _, err := tx.ExecContext(ctx, query, userID, email)
283 return err
284}