1package sqlite
2
3import (
4 "context"
5 "errors"
6 "strings"
7
8 "github.com/charmbracelet/log"
9 "github.com/charmbracelet/soft-serve/server/auth"
10 "github.com/charmbracelet/soft-serve/server/db"
11 "github.com/charmbracelet/soft-serve/server/db/sqlite"
12 "github.com/charmbracelet/soft-serve/server/sshutils"
13 "github.com/charmbracelet/soft-serve/server/utils"
14 "github.com/jmoiron/sqlx"
15 "golang.org/x/crypto/ssh"
16)
17
// SqliteAuthStore is a sqlite auth store.
//
// It is the auth.Auth value returned by newAuthStore, which init registers
// with the auth package under the name "sqlite".
type SqliteAuthStore struct {
	// db is the database handle taken from the creation context; newAuthStore
	// only accepts a *sqlite.Sqlite here.
	db db.Database
	// ctx is the context that was passed to newAuthStore.
	ctx context.Context
	// logger is the context logger with the "sqlite" prefix applied.
	logger *log.Logger
}
24
// init registers newAuthStore with the auth package under the name "sqlite"
// (presumably so the store can be selected by name — confirm against
// auth.Register).
func init() {
	auth.Register("sqlite", newAuthStore)
}
28
29func newAuthStore(ctx context.Context) (auth.Auth, error) {
30 sdb := db.FromContext(ctx)
31 if sdb == nil {
32 return nil, db.ErrNoDatabase
33 }
34
35 if _, ok := sdb.(*sqlite.Sqlite); !ok {
36 return nil, errors.New("database is not a SQLite database")
37 }
38
39 return &SqliteAuthStore{
40 db: sdb,
41 ctx: ctx,
42 logger: log.FromContext(ctx).WithPrefix("sqlite"),
43 }, nil
44}
45
46// Authenticate implements auth.Auth.
47func (d *SqliteAuthStore) Authenticate(ctx context.Context, method auth.AuthMethod) (auth.User, error) {
48 switch m := method.(type) {
49 case auth.PublicKey:
50 u, err := d.UserByPublicKey(ctx, m)
51 if err != nil {
52 return nil, err
53 }
54
55 return u, nil
56 default:
57 return nil, auth.ErrUnsupportedAuthMethod
58 }
59}
60
61// CreateUser creates a new user.
62func (d *SqliteAuthStore) CreateUser(ctx context.Context, username string, opts UserOptions) (auth.User, error) {
63 username = strings.ToLower(username)
64 if err := utils.ValidateUsername(username); err != nil {
65 return nil, err
66 }
67
68 var user *User
69 if err := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
70 stmt, err := tx.Prepare("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);")
71 if err != nil {
72 return err
73 }
74
75 defer stmt.Close() // nolint: errcheck
76 r, err := stmt.Exec(username, opts.Admin)
77 if err != nil {
78 return err
79 }
80
81 if len(opts.PublicKeys) > 0 {
82 userID, err := r.LastInsertId()
83 if err != nil {
84 d.logger.Error("error getting last insert id")
85 return err
86 }
87
88 for _, pk := range opts.PublicKeys {
89 stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
90 VALUES (?, ?, CURRENT_TIMESTAMP);`)
91 if err != nil {
92 return err
93 }
94
95 defer stmt.Close() // nolint: errcheck
96 if _, err := stmt.Exec(userID, sshutils.MarshalAuthorizedKey(pk)); err != nil {
97 return err
98 }
99 }
100 }
101
102 user = &User{
103 username: username,
104 isAdmin: opts.Admin,
105 publicKeys: opts.PublicKeys,
106 }
107 return nil
108 }); err != nil {
109 return nil, sqlite.WrapDbErr(err)
110 }
111
112 return user, nil
113}
114
115// DeleteUser deletes a user.
116func (d *SqliteAuthStore) DeleteUser(ctx context.Context, username string) error {
117 username = strings.ToLower(username)
118 if err := utils.ValidateUsername(username); err != nil {
119 return err
120 }
121
122 return sqlite.WrapDbErr(
123 sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
124 _, err := tx.Exec("DELETE FROM user WHERE username = ?", username)
125 return err
126 }),
127 )
128}
129
130// SetUsername sets the username of a user.
131func (d *SqliteAuthStore) SetUsername(ctx context.Context, username string, newUsername string) error {
132 username = strings.ToLower(username)
133 if err := utils.ValidateUsername(username); err != nil {
134 return err
135 }
136
137 return sqlite.WrapDbErr(
138 sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
139 _, err := tx.Exec("UPDATE user SET username = ? WHERE username = ?", newUsername, username)
140 return err
141 }),
142 )
143}
144
145// SetAdmin sets the admin flag of a user.
146func (d *SqliteAuthStore) SetAdmin(ctx context.Context, username string, admin bool) error {
147 username = strings.ToLower(username)
148 if err := utils.ValidateUsername(username); err != nil {
149 return err
150 }
151
152 return sqlite.WrapDbErr(
153 sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
154 _, err := tx.Exec("UPDATE user SET admin = ? WHERE username = ?", admin, username)
155 return err
156 }),
157 )
158}
159
160// User finds a user by username.
161func (d *SqliteAuthStore) User(ctx context.Context, username string) (auth.User, error) {
162 return d.user(ctx, username, nil)
163}
164
165// UserByPublicKey finds a user by public key.
166func (d *SqliteAuthStore) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (auth.User, error) {
167 return d.user(ctx, "", pk)
168}
169
170func (d *SqliteAuthStore) user(ctx context.Context, username string, pk ssh.PublicKey) (*User, error) {
171 if username == "" && pk == nil {
172 return nil, errors.New("username or public key must be provided")
173 }
174
175 user := &User{
176 username: username,
177 publicKeys: make([]ssh.PublicKey, 0),
178 }
179 if err := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
180 if username == "" {
181 row := tx.QueryRow(`SELECT user.username, user.admin
182 FROM public_key
183 INNER JOIN user ON user.id = public_key.user_id
184 WHERE public_key.public_key = ?;`, sshutils.MarshalAuthorizedKey(pk))
185 if err := row.Scan(&user.username, &user.isAdmin); err != nil {
186 return err
187 }
188 } else {
189 row := tx.QueryRow(`SELECT user.admin
190 FROM user
191 WHERE user.username = ?;`, username)
192 if err := row.Scan(&user.isAdmin); err != nil {
193 return err
194 }
195 }
196
197 rows, err := tx.Query(`SELECT public_key.public_key
198 FROM public_key
199 INNER JOIN user ON user.id = public_key.user_id
200 WHERE user.username = ?;`, user.username)
201 if err != nil {
202 return err
203 }
204
205 for rows.Next() {
206 var ak string
207 if err := rows.Scan(&ak); err != nil {
208 return err
209 }
210
211 if pk, _, err := sshutils.ParseAuthorizedKey(ak); err == nil {
212 user.publicKeys = append(user.publicKeys, pk)
213 }
214 }
215
216 return nil
217 }); err != nil {
218 return nil, sqlite.WrapDbErr(err)
219 }
220
221 return user, nil
222}
223
224// Users returns all users.
225//
226// TODO: pagination
227func (d *SqliteAuthStore) Users(ctx context.Context) ([]string, error) {
228 var users []string
229 if err := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
230 return tx.Select(&users, "SELECT username FROM user")
231 }); err != nil {
232 return nil, sqlite.WrapDbErr(err)
233 }
234
235 return users, nil
236}
237
238// AddPublicKey adds a public key to a user.
239//
240// It implements backend.Backend.
241func (d *SqliteAuthStore) AddPublicKey(ctx context.Context, username string, pk ssh.PublicKey) error {
242 username = strings.ToLower(username)
243 if err := utils.ValidateUsername(username); err != nil {
244 return err
245 }
246
247 return sqlite.WrapDbErr(
248 sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
249 var userID int
250 if err := tx.Get(&userID, "SELECT id FROM user WHERE username = ?", username); err != nil {
251 return err
252 }
253
254 _, err := tx.Exec(`INSERT INTO public_key (user_id, public_key, updated_at)
255 VALUES (?, ?, CURRENT_TIMESTAMP);`, userID, sshutils.MarshalAuthorizedKey(pk))
256 return err
257 }),
258 )
259}
260
261// RemovePublicKey removes a public key from a user.
262//
263// It implements backend.Backend.
264func (d *SqliteAuthStore) RemovePublicKey(ctx context.Context, username string, pk ssh.PublicKey) error {
265 return sqlite.WrapDbErr(
266 sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
267 _, err := tx.Exec(`DELETE FROM public_key
268 WHERE user_id = (SELECT id FROM user WHERE username = ?)
269 AND public_key = ?;`, username, sshutils.MarshalAuthorizedKey(pk))
270 return err
271 }),
272 )
273}
274
275// ListPublicKeys lists the public keys of a user.
276func (d *SqliteAuthStore) ListPublicKeys(ctx context.Context, username string) ([]ssh.PublicKey, error) {
277 username = strings.ToLower(username)
278 if err := utils.ValidateUsername(username); err != nil {
279 return nil, err
280 }
281
282 keys := make([]ssh.PublicKey, 0)
283 if err := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
284 var keyStrings []string
285 if err := tx.Select(&keyStrings, `SELECT public_key
286 FROM public_key
287 INNER JOIN user ON user.id = public_key.user_id
288 WHERE user.username = ?;`, username); err != nil {
289 return err
290 }
291
292 for _, keyString := range keyStrings {
293 key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyString))
294 if err != nil {
295 return err
296 }
297 keys = append(keys, key)
298 }
299
300 return nil
301 }); err != nil {
302 return nil, sqlite.WrapDbErr(err)
303 }
304
305 return keys, nil
306}