1package database
2
3import (
4 "context"
5 "strings"
6
7 "github.com/charmbracelet/soft-serve/server/db"
8 "github.com/charmbracelet/soft-serve/server/db/models"
9 "github.com/charmbracelet/soft-serve/server/sshutils"
10 "github.com/charmbracelet/soft-serve/server/store"
11 "github.com/charmbracelet/soft-serve/server/utils"
12 "golang.org/x/crypto/ssh"
13)
14
15type userStore struct{}
16
17var _ store.UserStore = (*userStore)(nil)
18
19// AddPublicKeyByUsername implements store.UserStore.
20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error {
21 username = strings.ToLower(username)
22 if err := utils.ValidateUsername(username); err != nil {
23 return err
24 }
25
26 var userID int64
27 if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {
28 return err
29 }
30
31 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
32 VALUES (?, ?, CURRENT_TIMESTAMP);`)
33 ak := sshutils.MarshalAuthorizedKey(pk)
34 _, err := tx.ExecContext(ctx, query, userID, ak)
35
36 return err
37}
38
39// CreateUser implements store.UserStore.
40func (*userStore) CreateUser(ctx context.Context, tx *db.Tx, username string, isAdmin bool, pks []ssh.PublicKey) error {
41 username = strings.ToLower(username)
42 if err := utils.ValidateUsername(username); err != nil {
43 return err
44 }
45
46 query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)
47 VALUES (?, ?, CURRENT_TIMESTAMP);`)
48 result, err := tx.ExecContext(ctx, query, username, isAdmin)
49 if err != nil {
50 return err
51 }
52
53 userID, err := result.LastInsertId()
54 if err != nil {
55 return err
56 }
57
58 for _, pk := range pks {
59 query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
60 VALUES (?, ?, CURRENT_TIMESTAMP);`)
61 ak := sshutils.MarshalAuthorizedKey(pk)
62 _, err := tx.ExecContext(ctx, query, userID, ak)
63 if err != nil {
64 return err
65 }
66 }
67
68 return nil
69}
70
71// DeleteUserByUsername implements store.UserStore.
72func (*userStore) DeleteUserByUsername(ctx context.Context, tx *db.Tx, username string) error {
73 username = strings.ToLower(username)
74 if err := utils.ValidateUsername(username); err != nil {
75 return err
76 }
77
78 query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)
79 _, err := tx.ExecContext(ctx, query, username)
80 return err
81}
82
83// FindUserByPublicKey implements store.UserStore.
84func (*userStore) FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.PublicKey) (models.User, error) {
85 var m models.User
86 query := tx.Rebind(`SELECT users.*
87 FROM users
88 INNER JOIN public_keys ON users.id = public_keys.user_id
89 WHERE public_keys.public_key = ?;`)
90 err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
91 return m, err
92}
93
94// FindUserByUsername implements store.UserStore.
95func (*userStore) FindUserByUsername(ctx context.Context, tx *db.Tx, username string) (models.User, error) {
96 username = strings.ToLower(username)
97 if err := utils.ValidateUsername(username); err != nil {
98 return models.User{}, err
99 }
100
101 var m models.User
102 query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)
103 err := tx.GetContext(ctx, &m, query, username)
104 return m, err
105}
106
107// GetAllUsers implements store.UserStore.
108func (*userStore) GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, error) {
109 var ms []models.User
110 query := tx.Rebind(`SELECT * FROM users;`)
111 err := tx.SelectContext(ctx, &ms, query)
112 return ms, err
113}
114
115// ListPublicKeysByUserID implements store.UserStore..
116func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int64) ([]ssh.PublicKey, error) {
117 var aks []string
118 query := tx.Rebind(`SELECT public_key FROM public_keys
119 WHERE user_id = ?
120 ORDER BY public_keys.id ASC;`)
121 err := tx.SelectContext(ctx, &aks, query, id)
122 if err != nil {
123 return nil, err
124 }
125
126 pks := make([]ssh.PublicKey, len(aks))
127 for i, ak := range aks {
128 pk, _, err := sshutils.ParseAuthorizedKey(ak)
129 if err != nil {
130 return nil, err
131 }
132 pks[i] = pk
133 }
134
135 return pks, nil
136}
137
138// ListPublicKeysByUsername implements store.UserStore.
139func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, username string) ([]ssh.PublicKey, error) {
140 username = strings.ToLower(username)
141 if err := utils.ValidateUsername(username); err != nil {
142 return nil, err
143 }
144
145 var aks []string
146 query := tx.Rebind(`SELECT public_key FROM public_keys
147 INNER JOIN users ON users.id = public_keys.user_id
148 WHERE users.username = ?
149 ORDER BY public_keys.id ASC;`)
150 err := tx.SelectContext(ctx, &aks, query, username)
151 if err != nil {
152 return nil, err
153 }
154
155 pks := make([]ssh.PublicKey, len(aks))
156 for i, ak := range aks {
157 pk, _, err := sshutils.ParseAuthorizedKey(ak)
158 if err != nil {
159 return nil, err
160 }
161 pks[i] = pk
162 }
163
164 return pks, nil
165}
166
167// RemovePublicKeyByUsername implements store.UserStore.
168func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error {
169 username = strings.ToLower(username)
170 if err := utils.ValidateUsername(username); err != nil {
171 return err
172 }
173
174 query := tx.Rebind(`DELETE FROM public_keys
175 WHERE user_id = (SELECT id FROM users WHERE username = ?)
176 AND public_key = ?;`)
177 _, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
178 return err
179}
180
181// SetAdminByUsername implements store.UserStore.
182func (*userStore) SetAdminByUsername(ctx context.Context, tx *db.Tx, username string, isAdmin bool) error {
183 username = strings.ToLower(username)
184 if err := utils.ValidateUsername(username); err != nil {
185 return err
186 }
187
188 query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)
189 _, err := tx.ExecContext(ctx, query, isAdmin, username)
190 return err
191}
192
193// SetUsernameByUsername implements store.UserStore.
194func (*userStore) SetUsernameByUsername(ctx context.Context, tx *db.Tx, username string, newUsername string) error {
195 username = strings.ToLower(username)
196 if err := utils.ValidateUsername(username); err != nil {
197 return err
198 }
199
200 newUsername = strings.ToLower(newUsername)
201 if err := utils.ValidateUsername(newUsername); err != nil {
202 return err
203 }
204
205 query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)
206 _, err := tx.ExecContext(ctx, query, newUsername, username)
207 return err
208}