1package database
2
3import (
4 "context"
5 "strings"
6
7 "github.com/charmbracelet/soft-serve/pkg/db"
8 "github.com/charmbracelet/soft-serve/pkg/db/models"
9 "github.com/charmbracelet/soft-serve/pkg/sshutils"
10 "github.com/charmbracelet/soft-serve/pkg/store"
11 "github.com/charmbracelet/soft-serve/pkg/utils"
12 "golang.org/x/crypto/ssh"
13)
14
15type userStore struct{}
16
17var _ store.UserStore = (*userStore)(nil)
18
19// AddPublicKeyByUsername implements store.UserStore.
20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
21 username = strings.ToLower(username)
22 if err := utils.ValidateUsername(username); err != nil {
23 return err //nolint:wrapcheck
24 }
25
26 var userID int64
27 if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {
28 return err //nolint:wrapcheck
29 }
30
31 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
32 VALUES (?, ?, CURRENT_TIMESTAMP);`)
33 ak := sshutils.MarshalAuthorizedKey(pk)
34 _, err := tx.ExecContext(ctx, query, userID, ak)
35
36 return err //nolint:wrapcheck
37}
38
39// CreateUser implements store.UserStore.
40func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error {
41 username = strings.ToLower(username)
42 if err := utils.ValidateUsername(username); err != nil {
43 return err //nolint:wrapcheck
44 }
45
46 query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)
47 VALUES (?, ?, CURRENT_TIMESTAMP) RETURNING id;`)
48
49 var userID int64
50 if err := tx.GetContext(ctx, &userID, query, username, isAdmin); err != nil {
51 return err //nolint:wrapcheck
52 }
53
54 for _, pk := range pks {
55 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
56 VALUES (?, ?, CURRENT_TIMESTAMP);`)
57 ak := sshutils.MarshalAuthorizedKey(pk)
58 _, err := tx.ExecContext(ctx, query, userID, ak)
59 if err != nil {
60 return err //nolint:wrapcheck
61 }
62 }
63
64 return nil
65}
66
67// DeleteUserByUsername implements store.UserStore.
68func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
69 username = strings.ToLower(username)
70 if err := utils.ValidateUsername(username); err != nil {
71 return err //nolint:wrapcheck
72 }
73
74 query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)
75 _, err := tx.ExecContext(ctx, query, username)
76 return err //nolint:wrapcheck
77}
78
79// GetUserByID implements store.UserStore.
80func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
81 var m models.User
82 query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
83 err := tx.GetContext(ctx, &m, query, id)
84 return m, err //nolint:wrapcheck
85}
86
87// FindUserByPublicKey implements store.UserStore.
88func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
89 var m models.User
90 query := tx.Rebind(`SELECT users.*
91 FROM users
92 INNER JOIN public_keys ON users.id = public_keys.user_id
93 WHERE public_keys.public_key = ?;`)
94 err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
95 return m, err //nolint:wrapcheck
96}
97
98// FindUserByUsername implements store.UserStore.
99func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
100 username = strings.ToLower(username)
101 if err := utils.ValidateUsername(username); err != nil {
102 return models.User{}, err //nolint:wrapcheck
103 }
104
105 var m models.User
106 query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)
107 err := tx.GetContext(ctx, &m, query, username)
108 return m, err //nolint:wrapcheck
109}
110
111// FindUserByAccessToken implements store.UserStore.
112func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {
113 var m models.User
114 query := tx.Rebind(`SELECT users.*
115 FROM users
116 INNER JOIN access_tokens ON users.id = access_tokens.user_id
117 WHERE access_tokens.token = ?;`)
118 err := tx.GetContext(ctx, &m, query, token)
119 return m, err //nolint:wrapcheck
120}
121
122// GetAllUsers implements store.UserStore.
123func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
124 var ms []models.User
125 query := tx.Rebind(`SELECT * FROM users;`)
126 err := tx.SelectContext(ctx, &ms, query)
127 return ms, err //nolint:wrapcheck
128}
129
130// ListPublicKeysByUserID implements store.UserStore..
131func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
132 var aks []string
133 query := tx.Rebind(`SELECT public_key FROM public_keys
134 WHERE user_id = ?
135 ORDER BY public_keys.id ASC;`)
136 err := tx.SelectContext(ctx, &aks, query, id)
137 if err != nil {
138 return nil, err //nolint:wrapcheck
139 }
140
141 pks := make([]ssh.PublicKey, len(aks))
142 for i, ak := range aks {
143 pk, _, err := sshutils.ParseAuthorizedKey(ak)
144 if err != nil {
145 return nil, err //nolint:wrapcheck
146 }
147 pks[i] = pk
148 }
149
150 return pks, nil
151}
152
153// ListPublicKeysByUsername implements store.UserStore.
154func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
155 username = strings.ToLower(username)
156 if err := utils.ValidateUsername(username); err != nil {
157 return nil, err //nolint:wrapcheck
158 }
159
160 var aks []string
161 query := tx.Rebind(`SELECT public_key FROM public_keys
162 INNER JOIN users ON users.id = public_keys.user_id
163 WHERE users.username = ?
164 ORDER BY public_keys.id ASC;`)
165 err := tx.SelectContext(ctx, &aks, query, username)
166 if err != nil {
167 return nil, err //nolint:wrapcheck
168 }
169
170 pks := make([]ssh.PublicKey, len(aks))
171 for i, ak := range aks {
172 pk, _, err := sshutils.ParseAuthorizedKey(ak)
173 if err != nil {
174 return nil, err //nolint:wrapcheck
175 }
176 pks[i] = pk
177 }
178
179 return pks, nil
180}
181
182// RemovePublicKeyByUsername implements store.UserStore.
183func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
184 username = strings.ToLower(username)
185 if err := utils.ValidateUsername(username); err != nil {
186 return err //nolint:wrapcheck
187 }
188
189 query := tx.Rebind(`DELETE FROM public_keys
190 WHERE user_id = (SELECT id FROM users WHERE username = ?)
191 AND public_key = ?;`)
192 _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
193 return err //nolint:wrapcheck
194}
195
196// SetAdminByUsername implements store.UserStore.
197func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
198 username = strings.ToLower(username)
199 if err := utils.ValidateUsername(username); err != nil {
200 return err //nolint:wrapcheck
201 }
202
203 query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)
204 _, err := tx.ExecContext(ctx, query, isAdmin, username)
205 return err //nolint:wrapcheck
206}
207
208// SetUsernameByUsername implements store.UserStore.
209func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
210 username = strings.ToLower(username)
211 if err := utils.ValidateUsername(username); err != nil {
212 return err //nolint:wrapcheck
213 }
214
215 newUsername = strings.ToLower(newUsername)
216 if err := utils.ValidateUsername(newUsername); err != nil {
217 return err //nolint:wrapcheck
218 }
219
220 query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)
221 _, err := tx.ExecContext(ctx, query, newUsername, username)
222 return err //nolint:wrapcheck
223}
224
225// SetUserPassword implements store.UserStore.
226func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
227 query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
228 _, err := tx.ExecContext(ctx, query, password, userID)
229 return err //nolint:wrapcheck
230}
231
232// SetUserPasswordByUsername implements store.UserStore.
233func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
234 username = strings.ToLower(username)
235 if err := utils.ValidateUsername(username); err != nil {
236 return err //nolint:wrapcheck
237 }
238
239 query := tx.Rebind(`UPDATE users SET password = ? WHERE username = ?;`)
240 _, err := tx.ExecContext(ctx, query, password, username)
241 return err //nolint:wrapcheck
242}