sqlite.go

  1package sqlite
  2
  3import (
  4	"context"
  5	"errors"
  6	"strings"
  7
  8	"github.com/charmbracelet/log"
  9	"github.com/charmbracelet/soft-serve/server/auth"
 10	"github.com/charmbracelet/soft-serve/server/db"
 11	"github.com/charmbracelet/soft-serve/server/db/sqlite"
 12	"github.com/charmbracelet/soft-serve/server/sshutils"
 13	"github.com/charmbracelet/soft-serve/server/utils"
 14	"github.com/jmoiron/sqlx"
 15	"golang.org/x/crypto/ssh"
 16)
 17
 18// SqliteAuthStore is a sqlite auth store.
 19type SqliteAuthStore struct {
 20	db     db.Database
 21	ctx    context.Context
 22	logger *log.Logger
 23}
 24
 25func init() {
 26	auth.Register("sqlite", newAuthStore)
 27}
 28
 29func newAuthStore(ctx context.Context) (auth.Auth, error) {
 30	sdb := db.FromContext(ctx)
 31	if sdb == nil {
 32		return nil, db.ErrNoDatabase
 33	}
 34
 35	if _, ok := sdb.(*sqlite.Sqlite); !ok {
 36		return nil, errors.New("database is not a SQLite database")
 37	}
 38
 39	return &SqliteAuthStore{
 40		db:     sdb,
 41		ctx:    ctx,
 42		logger: log.FromContext(ctx).WithPrefix("sqlite"),
 43	}, nil
 44}
 45
 46// Authenticate implements auth.Auth.
 47func (d *SqliteAuthStore) Authenticate(ctx context.Context, method auth.AuthMethod) (auth.User, error) {
 48	switch m := method.(type) {
 49	case auth.PublicKey:
 50		u, err := d.UserByPublicKey(ctx, m)
 51		if err != nil {
 52			return nil, err
 53		}
 54
 55		return u, nil
 56	default:
 57		return nil, auth.ErrUnsupportedAuthMethod
 58	}
 59}
 60
 61// CreateUser creates a new user.
 62func (d *SqliteAuthStore) CreateUser(ctx context.Context, username string, opts UserOptions) (auth.User, error) {
 63	username = strings.ToLower(username)
 64	if err := utils.ValidateUsername(username); err != nil {
 65		return nil, err
 66	}
 67
 68	var user *User
 69	if err := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
 70		stmt, err := tx.Prepare("INSERT INTO user (username, admin, updated_at) VALUES (?, ?, CURRENT_TIMESTAMP);")
 71		if err != nil {
 72			return err
 73		}
 74
 75		defer stmt.Close() // nolint: errcheck
 76		r, err := stmt.Exec(username, opts.Admin)
 77		if err != nil {
 78			return err
 79		}
 80
 81		if len(opts.PublicKeys) > 0 {
 82			userID, err := r.LastInsertId()
 83			if err != nil {
 84				d.logger.Error("error getting last insert id")
 85				return err
 86			}
 87
 88			for _, pk := range opts.PublicKeys {
 89				stmt, err := tx.Prepare(`INSERT INTO public_key (user_id, public_key, updated_at)
 90					VALUES (?, ?, CURRENT_TIMESTAMP);`)
 91				if err != nil {
 92					return err
 93				}
 94
 95				defer stmt.Close() // nolint: errcheck
 96				if _, err := stmt.Exec(userID, sshutils.MarshalAuthorizedKey(pk)); err != nil {
 97					return err
 98				}
 99			}
100		}
101
102		user = &User{
103			username:   username,
104			isAdmin:    opts.Admin,
105			publicKeys: opts.PublicKeys,
106		}
107		return nil
108	}); err != nil {
109		return nil, sqlite.WrapDbErr(err)
110	}
111
112	return user, nil
113}
114
// DeleteUser removes the user called username.
//
// The name is lower-cased and validated before the DELETE runs inside a
// transaction. Database errors are passed through sqlite.WrapDbErr; the
// number of deleted rows is not checked.
func (d *SqliteAuthStore) DeleteUser(ctx context.Context, username string) error {
	name := strings.ToLower(username)
	if err := utils.ValidateUsername(name); err != nil {
		return err
	}

	txErr := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
		_, err := tx.Exec("DELETE FROM user WHERE username = ?", name)
		return err
	})
	return sqlite.WrapDbErr(txErr)
}
129
// SetUsername renames the user called username to newUsername.
//
// Both names are lower-cased and validated with utils.ValidateUsername. The
// stored name therefore follows the same rules CreateUser applies, and the
// methods that lower-case their username argument can still find it.
// Database errors are passed through sqlite.WrapDbErr.
func (d *SqliteAuthStore) SetUsername(ctx context.Context, username string, newUsername string) error {
	username = strings.ToLower(username)
	if err := utils.ValidateUsername(username); err != nil {
		return err
	}

	// newUsername is persisted, so it must meet the same rules as a name given
	// to CreateUser. Otherwise a mixed-case or invalid name could be stored
	// that later lower-cased lookups would never match.
	newUsername = strings.ToLower(newUsername)
	if err := utils.ValidateUsername(newUsername); err != nil {
		return err
	}

	return sqlite.WrapDbErr(
		sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
			_, err := tx.Exec("UPDATE user SET username = ? WHERE username = ?", newUsername, username)
			return err
		}),
	)
}
144
// SetAdmin updates the admin flag of the user called username.
//
// The name is lower-cased and validated first; the UPDATE then runs inside a
// transaction and any database error is passed through sqlite.WrapDbErr.
func (d *SqliteAuthStore) SetAdmin(ctx context.Context, username string, admin bool) error {
	name := strings.ToLower(username)
	if err := utils.ValidateUsername(name); err != nil {
		return err
	}

	update := func(tx *sqlx.Tx) error {
		_, err := tx.Exec("UPDATE user SET admin = ? WHERE username = ?", admin, name)
		return err
	}
	return sqlite.WrapDbErr(sqlite.WrapTx(d.db.DBx(), ctx, update))
}
159
// User finds a user by username.
//
// The username is lower-cased first, matching how CreateUser stores names. On
// failure the returned auth.User is an untyped nil. Returning d.user's
// *User directly would produce a non-nil interface wrapping a nil pointer.
func (d *SqliteAuthStore) User(ctx context.Context, username string) (auth.User, error) {
	u, err := d.user(ctx, strings.ToLower(username), nil)
	if err != nil {
		return nil, err
	}
	return u, nil
}
164
// UserByPublicKey finds the user that owns the public key pk.
//
// On failure the returned auth.User is an untyped nil. Returning d.user's
// *User directly would produce a non-nil interface wrapping a nil pointer,
// which defeats `u == nil` checks in callers.
func (d *SqliteAuthStore) UserByPublicKey(ctx context.Context, pk ssh.PublicKey) (auth.User, error) {
	u, err := d.user(ctx, "", pk)
	if err != nil {
		return nil, err
	}
	return u, nil
}
169
// user loads a user and all of its public keys inside one transaction.
//
// When username is non-empty the user is looked up by name; otherwise it is
// looked up by pk. It returns an error when neither is provided. Stored keys
// that sshutils.ParseAuthorizedKey cannot parse are skipped. Database errors
// are passed through sqlite.WrapDbErr, and no user is returned with them.
func (d *SqliteAuthStore) user(ctx context.Context, username string, pk ssh.PublicKey) (*User, error) {
	if username == "" && pk == nil {
		return nil, errors.New("username or public key must be provided")
	}

	user := &User{
		username:   username,
		publicKeys: make([]ssh.PublicKey, 0),
	}
	if err := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
		if username == "" {
			// Resolve the owning user (and its admin flag) from the key.
			row := tx.QueryRow(`SELECT user.username, user.admin
			FROM public_key
			INNER JOIN user ON user.id = public_key.user_id
			WHERE public_key.public_key = ?;`, sshutils.MarshalAuthorizedKey(pk))
			if err := row.Scan(&user.username, &user.isAdmin); err != nil {
				return err
			}
		} else {
			row := tx.QueryRow(`SELECT user.admin
			FROM user
			WHERE user.username = ?;`, username)
			if err := row.Scan(&user.isAdmin); err != nil {
				return err
			}
		}

		rows, err := tx.Query(`SELECT public_key.public_key
			FROM public_key
			INNER JOIN user ON user.id = public_key.user_id
			WHERE user.username = ?;`, user.username)
		if err != nil {
			return err
		}
		// Release the result set on every path, including early Scan errors.
		defer rows.Close() // nolint: errcheck

		for rows.Next() {
			var ak string
			if err := rows.Scan(&ak); err != nil {
				return err
			}

			// Best effort: an unparsable stored key is skipped rather than
			// failing the whole lookup.
			if key, _, err := sshutils.ParseAuthorizedKey(ak); err == nil {
				user.publicKeys = append(user.publicKeys, key)
			}
		}

		// rows.Next returns false both at the end of the results and on error.
		// Surface the error instead of returning a truncated key list.
		return rows.Err()
	}); err != nil {
		return nil, sqlite.WrapDbErr(err)
	}

	return user, nil
}
223
// Users returns the usernames of all users in the store.
//
// The SELECT runs inside a transaction; database errors are passed through
// sqlite.WrapDbErr and no names are returned with them.
//
// TODO: pagination
func (d *SqliteAuthStore) Users(ctx context.Context) ([]string, error) {
	var names []string
	txErr := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
		return tx.Select(&names, "SELECT username FROM user")
	})
	if txErr != nil {
		return nil, sqlite.WrapDbErr(txErr)
	}
	return names, nil
}
237
// AddPublicKey stores pk for the user called username.
//
// The username is lower-cased and validated first. The user's id lookup and
// the key insert share one transaction, so if the lookup fails nothing is
// inserted. Database errors are passed through sqlite.WrapDbErr.
//
// It implements backend.Backend.
func (d *SqliteAuthStore) AddPublicKey(ctx context.Context, username string, pk ssh.PublicKey) error {
	username = strings.ToLower(username)
	if err := utils.ValidateUsername(username); err != nil {
		return err
	}

	txErr := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
		var id int
		if err := tx.Get(&id, "SELECT id FROM user WHERE username = ?", username); err != nil {
			return err
		}

		authorizedKey := sshutils.MarshalAuthorizedKey(pk)
		_, err := tx.Exec(`INSERT INTO public_key (user_id, public_key, updated_at)
		VALUES (?, ?, CURRENT_TIMESTAMP);`, id, authorizedKey)
		return err
	})
	return sqlite.WrapDbErr(txErr)
}
260
// RemovePublicKey deletes pk from the keys of the user called username.
//
// The username is lower-cased and validated the same way AddPublicKey does it.
// A key added under a mixed-case name can therefore be removed with that same
// name. The number of deleted rows is not checked. Database errors are passed
// through sqlite.WrapDbErr.
//
// It implements backend.Backend.
func (d *SqliteAuthStore) RemovePublicKey(ctx context.Context, username string, pk ssh.PublicKey) error {
	username = strings.ToLower(username)
	if err := utils.ValidateUsername(username); err != nil {
		return err
	}

	return sqlite.WrapDbErr(
		sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
			_, err := tx.Exec(`DELETE FROM public_key
			WHERE user_id = (SELECT id FROM user WHERE username = ?)
			AND public_key = ?;`, username, sshutils.MarshalAuthorizedKey(pk))
			return err
		}),
	)
}
274
// ListPublicKeys returns the public keys stored for the user called username.
//
// The username is lower-cased and validated first. Every stored key is parsed
// with ssh.ParseAuthorizedKey. The first key that fails to parse aborts the
// call with that parse error, passed through sqlite.WrapDbErr like any
// database error.
func (d *SqliteAuthStore) ListPublicKeys(ctx context.Context, username string) ([]ssh.PublicKey, error) {
	username = strings.ToLower(username)
	if err := utils.ValidateUsername(username); err != nil {
		return nil, err
	}

	result := make([]ssh.PublicKey, 0)
	txErr := sqlite.WrapTx(d.db.DBx(), ctx, func(tx *sqlx.Tx) error {
		query := `SELECT public_key
			FROM public_key
			INNER JOIN user ON user.id = public_key.user_id
			WHERE user.username = ?;`

		var encoded []string
		if err := tx.Select(&encoded, query, username); err != nil {
			return err
		}

		for i := range encoded {
			key, _, _, _, err := ssh.ParseAuthorizedKey([]byte(encoded[i]))
			if err != nil {
				return err
			}
			result = append(result, key)
		}
		return nil
	})
	if txErr != nil {
		return nil, sqlite.WrapDbErr(txErr)
	}
	return result, nil
}