1package database
2
3import (
4 "context"
5 "strings"
6
7 "github.com/charmbracelet/soft-serve/server/db"
8 "github.com/charmbracelet/soft-serve/server/db/models"
9 "github.com/charmbracelet/soft-serve/server/sshutils"
10 "github.com/charmbracelet/soft-serve/server/store"
11 "github.com/charmbracelet/soft-serve/server/utils"
12 "golang.org/x/crypto/ssh"
13)
14
// userStore is the SQL implementation of store.UserStore. The struct has no
// fields; each method runs its statements on the db.Handler it is given.
type userStore struct{}

// Compile-time check that *userStore satisfies store.UserStore.
var _ store.UserStore = (*userStore)(nil)
18
19// AddPublicKeyByUsername implements store.UserStore.
20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
21 username = strings.ToLower(username)
22 if err := utils.ValidateUsername(username); err != nil {
23 return err
24 }
25
26 var userID int64
27 if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {
28 return err
29 }
30
31 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
32 VALUES (?, ?, CURRENT_TIMESTAMP);`)
33 ak := sshutils.MarshalAuthorizedKey(pk)
34 _, err := tx.ExecContext(ctx, query, userID, ak)
35
36 return err
37}
38
39// CreateUser implements store.UserStore.
40func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error {
41 username = strings.ToLower(username)
42 if err := utils.ValidateUsername(username); err != nil {
43 return err
44 }
45
46 query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)
47 VALUES (?, ?, CURRENT_TIMESTAMP);`)
48 result, err := tx.ExecContext(ctx, query, username, isAdmin)
49 if err != nil {
50 return err
51 }
52
53 userID, err := result.LastInsertId()
54 if err != nil {
55 return err
56 }
57
58 for _, pk := range pks {
59 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
60 VALUES (?, ?, CURRENT_TIMESTAMP);`)
61 ak := sshutils.MarshalAuthorizedKey(pk)
62 _, err := tx.ExecContext(ctx, query, userID, ak)
63 if err != nil {
64 return err
65 }
66 }
67
68 return nil
69}
70
71// DeleteUserByUsername implements store.UserStore.
72func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
73 username = strings.ToLower(username)
74 if err := utils.ValidateUsername(username); err != nil {
75 return err
76 }
77
78 query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)
79 _, err := tx.ExecContext(ctx, query, username)
80 return err
81}
82
83// GetUserByID implements store.UserStore.
84func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
85 var m models.User
86 query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
87 err := tx.GetContext(ctx, &m, query, id)
88 return m, err
89}
90
91// FindUserByPublicKey implements store.UserStore.
92func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
93 var m models.User
94 query := tx.Rebind(`SELECT users.*
95 FROM users
96 INNER JOIN public_keys ON users.id = public_keys.user_id
97 WHERE public_keys.public_key = ?;`)
98 err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
99 return m, err
100}
101
102// FindUserByUsername implements store.UserStore.
103func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
104 username = strings.ToLower(username)
105 if err := utils.ValidateUsername(username); err != nil {
106 return models.User{}, err
107 }
108
109 var m models.User
110 query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)
111 err := tx.GetContext(ctx, &m, query, username)
112 return m, err
113}
114
115// FindUserByAccessToken implements store.UserStore.
116func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {
117 var m models.User
118 query := tx.Rebind(`SELECT users.*
119 FROM users
120 INNER JOIN access_tokens ON users.id = access_tokens.user_id
121 WHERE access_tokens.token = ?;`)
122 err := tx.GetContext(ctx, &m, query, token)
123 return m, err
124}
125
126// GetAllUsers implements store.UserStore.
127func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
128 var ms []models.User
129 query := tx.Rebind(`SELECT * FROM users;`)
130 err := tx.SelectContext(ctx, &ms, query)
131 return ms, err
132}
133
134// ListPublicKeysByUserID implements store.UserStore..
135func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
136 var aks []string
137 query := tx.Rebind(`SELECT public_key FROM public_keys
138 WHERE user_id = ?
139 ORDER BY public_keys.id ASC;`)
140 err := tx.SelectContext(ctx, &aks, query, id)
141 if err != nil {
142 return nil, err
143 }
144
145 pks := make([]ssh.PublicKey, len(aks))
146 for i, ak := range aks {
147 pk, _, err := sshutils.ParseAuthorizedKey(ak)
148 if err != nil {
149 return nil, err
150 }
151 pks[i] = pk
152 }
153
154 return pks, nil
155}
156
157// ListPublicKeysByUsername implements store.UserStore.
158func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
159 username = strings.ToLower(username)
160 if err := utils.ValidateUsername(username); err != nil {
161 return nil, err
162 }
163
164 var aks []string
165 query := tx.Rebind(`SELECT public_key FROM public_keys
166 INNER JOIN users ON users.id = public_keys.user_id
167 WHERE users.username = ?
168 ORDER BY public_keys.id ASC;`)
169 err := tx.SelectContext(ctx, &aks, query, username)
170 if err != nil {
171 return nil, err
172 }
173
174 pks := make([]ssh.PublicKey, len(aks))
175 for i, ak := range aks {
176 pk, _, err := sshutils.ParseAuthorizedKey(ak)
177 if err != nil {
178 return nil, err
179 }
180 pks[i] = pk
181 }
182
183 return pks, nil
184}
185
186// RemovePublicKeyByUsername implements store.UserStore.
187func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
188 username = strings.ToLower(username)
189 if err := utils.ValidateUsername(username); err != nil {
190 return err
191 }
192
193 query := tx.Rebind(`DELETE FROM public_keys
194 WHERE user_id = (SELECT id FROM users WHERE username = ?)
195 AND public_key = ?;`)
196 _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
197 return err
198}
199
200// SetAdminByUsername implements store.UserStore.
201func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
202 username = strings.ToLower(username)
203 if err := utils.ValidateUsername(username); err != nil {
204 return err
205 }
206
207 query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)
208 _, err := tx.ExecContext(ctx, query, isAdmin, username)
209 return err
210}
211
212// SetUsernameByUsername implements store.UserStore.
213func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
214 username = strings.ToLower(username)
215 if err := utils.ValidateUsername(username); err != nil {
216 return err
217 }
218
219 newUsername = strings.ToLower(newUsername)
220 if err := utils.ValidateUsername(newUsername); err != nil {
221 return err
222 }
223
224 query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)
225 _, err := tx.ExecContext(ctx, query, newUsername, username)
226 return err
227}
228
229// SetUserPassword implements store.UserStore.
230func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
231 query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
232 _, err := tx.ExecContext(ctx, query, password, userID)
233 return err
234}
235
236// SetUserPasswordByUsername implements store.UserStore.
237func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
238 username = strings.ToLower(username)
239 if err := utils.ValidateUsername(username); err != nil {
240 return err
241 }
242
243 query := tx.Rebind(`UPDATE users SET password = ? WHERE username = ?;`)
244 _, err := tx.ExecContext(ctx, query, password, username)
245 return err
246}