1package database
2
3import (
4 "context"
5 "strings"
6
7 "github.com/charmbracelet/soft-serve/server/db"
8 "github.com/charmbracelet/soft-serve/server/db/models"
9 "github.com/charmbracelet/soft-serve/server/sshutils"
10 "github.com/charmbracelet/soft-serve/server/store"
11 "github.com/charmbracelet/soft-serve/server/utils"
12 "golang.org/x/crypto/ssh"
13)
14
15type userStore struct{}
16
17var _ store.UserStore = (*userStore)(nil)
18
19// AddPublicKeyByUsername implements store.UserStore.
20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
21 username = strings.ToLower(username)
22 if err := utils.ValidateUsername(username); err != nil {
23 return err
24 }
25
26 var userID int64
27 if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {
28 return err
29 }
30
31 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
32 VALUES (?, ?, CURRENT_TIMESTAMP);`)
33 ak := sshutils.MarshalAuthorizedKey(pk)
34 _, err := tx.ExecContext(ctx, query, userID, ak)
35
36 return err
37}
38
39// CreateUser implements store.UserStore.
40func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error {
41 username = strings.ToLower(username)
42 if err := utils.ValidateUsername(username); err != nil {
43 return err
44 }
45
46 query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)
47 VALUES (?, ?, CURRENT_TIMESTAMP);`)
48 result, err := tx.ExecContext(ctx, query, username, isAdmin)
49 if err != nil {
50 return err
51 }
52
53 userID, err := result.LastInsertId()
54 if err != nil {
55 return err
56 }
57
58 for _, pk := range pks {
59 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
60 VALUES (?, ?, CURRENT_TIMESTAMP);`)
61 ak := sshutils.MarshalAuthorizedKey(pk)
62 _, err := tx.ExecContext(ctx, query, userID, ak)
63 if err != nil {
64 return err
65 }
66 }
67
68 return nil
69}
70
71// DeleteUserByUsername implements store.UserStore.
72func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
73 username = strings.ToLower(username)
74 if err := utils.ValidateUsername(username); err != nil {
75 return err
76 }
77
78 query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)
79 _, err := tx.ExecContext(ctx, query, username)
80 return err
81}
82
83// GetUserByID implements store.UserStore.
84func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
85 var m models.User
86 query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
87 err := tx.GetContext(ctx, &m, query, id)
88 return m, err
89}
90
91// FindUserByPublicKey implements store.UserStore.
92func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
93 var m models.User
94 query := tx.Rebind(`SELECT users.*
95 FROM users
96 INNER JOIN public_keys ON users.id = public_keys.user_id
97 WHERE public_keys.public_key = ?;`)
98 err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
99 return m, err
100}
101
102// FindUserByUsername implements store.UserStore.
103func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
104 username = strings.ToLower(username)
105 if err := utils.ValidateUsername(username); err != nil {
106 return models.User{}, err
107 }
108
109 var m models.User
110 query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)
111 err := tx.GetContext(ctx, &m, query, username)
112 return m, err
113}
114
115// GetAllUsers implements store.UserStore.
116func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
117 var ms []models.User
118 query := tx.Rebind(`SELECT * FROM users;`)
119 err := tx.SelectContext(ctx, &ms, query)
120 return ms, err
121}
122
123// ListPublicKeysByUserID implements store.UserStore..
124func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
125 var aks []string
126 query := tx.Rebind(`SELECT public_key FROM public_keys
127 WHERE user_id = ?
128 ORDER BY public_keys.id ASC;`)
129 err := tx.SelectContext(ctx, &aks, query, id)
130 if err != nil {
131 return nil, err
132 }
133
134 pks := make([]ssh.PublicKey, len(aks))
135 for i, ak := range aks {
136 pk, _, err := sshutils.ParseAuthorizedKey(ak)
137 if err != nil {
138 return nil, err
139 }
140 pks[i] = pk
141 }
142
143 return pks, nil
144}
145
146// ListPublicKeysByUsername implements store.UserStore.
147func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
148 username = strings.ToLower(username)
149 if err := utils.ValidateUsername(username); err != nil {
150 return nil, err
151 }
152
153 var aks []string
154 query := tx.Rebind(`SELECT public_key FROM public_keys
155 INNER JOIN users ON users.id = public_keys.user_id
156 WHERE users.username = ?
157 ORDER BY public_keys.id ASC;`)
158 err := tx.SelectContext(ctx, &aks, query, username)
159 if err != nil {
160 return nil, err
161 }
162
163 pks := make([]ssh.PublicKey, len(aks))
164 for i, ak := range aks {
165 pk, _, err := sshutils.ParseAuthorizedKey(ak)
166 if err != nil {
167 return nil, err
168 }
169 pks[i] = pk
170 }
171
172 return pks, nil
173}
174
175// RemovePublicKeyByUsername implements store.UserStore.
176func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
177 username = strings.ToLower(username)
178 if err := utils.ValidateUsername(username); err != nil {
179 return err
180 }
181
182 query := tx.Rebind(`DELETE FROM public_keys
183 WHERE user_id = (SELECT id FROM users WHERE username = ?)
184 AND public_key = ?;`)
185 _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
186 return err
187}
188
189// SetAdminByUsername implements store.UserStore.
190func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
191 username = strings.ToLower(username)
192 if err := utils.ValidateUsername(username); err != nil {
193 return err
194 }
195
196 query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)
197 _, err := tx.ExecContext(ctx, query, isAdmin, username)
198 return err
199}
200
201// SetUsernameByUsername implements store.UserStore.
202func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
203 username = strings.ToLower(username)
204 if err := utils.ValidateUsername(username); err != nil {
205 return err
206 }
207
208 newUsername = strings.ToLower(newUsername)
209 if err := utils.ValidateUsername(newUsername); err != nil {
210 return err
211 }
212
213 query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)
214 _, err := tx.ExecContext(ctx, query, newUsername, username)
215 return err
216}
217
218// SetUserPassword implements store.UserStore.
219func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
220 query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
221 _, err := tx.ExecContext(ctx, query, password, userID)
222 return err
223}
224
225// SetUserPasswordByUsername implements store.UserStore.
226func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
227 username = strings.ToLower(username)
228 if err := utils.ValidateUsername(username); err != nil {
229 return err
230 }
231
232 query := tx.Rebind(`UPDATE users SET password = ? WHERE username = ?;`)
233 _, err := tx.ExecContext(ctx, query, password, username)
234 return err
235}