1package database
  2
  3import (
  4	"context"
  5	"strings"
  6
  7	"github.com/charmbracelet/soft-serve/server/db"
  8	"github.com/charmbracelet/soft-serve/server/db/models"
  9	"github.com/charmbracelet/soft-serve/server/sshutils"
 10	"github.com/charmbracelet/soft-serve/server/store"
 11	"github.com/charmbracelet/soft-serve/server/utils"
 12	"golang.org/x/crypto/ssh"
 13)
 14
 15type userStore struct{}
 16
 17var _ store.UserStore = (*userStore)(nil)
 18
 19// AddPublicKeyByUsername implements store.UserStore.
 20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error {
 21	username = strings.ToLower(username)
 22	if err := utils.ValidateUsername(username); err != nil {
 23		return err
 24	}
 25
 26	var userID int64
 27	if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {
 28		return err
 29	}
 30
 31	query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
 32			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 33	ak := sshutils.MarshalAuthorizedKey(pk)
 34	_, err := tx.ExecContext(ctx, query, userID, ak)
 35
 36	return err
 37}
 38
 39// CreateUser implements store.UserStore.
 40func (*userStore) CreateUser(ctx context.Context, tx *db.Tx, username string, isAdmin bool, pks []ssh.PublicKey) error {
 41	username = strings.ToLower(username)
 42	if err := utils.ValidateUsername(username); err != nil {
 43		return err
 44	}
 45
 46	query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)
 47			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 48	result, err := tx.ExecContext(ctx, query, username, isAdmin)
 49	if err != nil {
 50		return err
 51	}
 52
 53	userID, err := result.LastInsertId()
 54	if err != nil {
 55		return err
 56	}
 57
 58	for _, pk := range pks {
 59		query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
 60			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 61		ak := sshutils.MarshalAuthorizedKey(pk)
 62		_, err := tx.ExecContext(ctx, query, userID, ak)
 63		if err != nil {
 64			return err
 65		}
 66	}
 67
 68	return nil
 69}
 70
 71// DeleteUserByUsername implements store.UserStore.
 72func (*userStore) DeleteUserByUsername(ctx context.Context, tx *db.Tx, username string) error {
 73	username = strings.ToLower(username)
 74	if err := utils.ValidateUsername(username); err != nil {
 75		return err
 76	}
 77
 78	query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)
 79	_, err := tx.ExecContext(ctx, query, username)
 80	return err
 81}
 82
 83// FindUserByPublicKey implements store.UserStore.
 84func (*userStore) FindUserByPublicKey(ctx context.Context, tx *db.Tx, pk ssh.PublicKey) (models.User, error) {
 85	var m models.User
 86	query := tx.Rebind(`SELECT users.*
 87			FROM users
 88			INNER JOIN public_keys ON users.id = public_keys.user_id
 89			WHERE public_keys.public_key = ?;`)
 90	err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
 91	return m, err
 92}
 93
 94// FindUserByUsername implements store.UserStore.
 95func (*userStore) FindUserByUsername(ctx context.Context, tx *db.Tx, username string) (models.User, error) {
 96	username = strings.ToLower(username)
 97	if err := utils.ValidateUsername(username); err != nil {
 98		return models.User{}, err
 99	}
100
101	var m models.User
102	query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)
103	err := tx.GetContext(ctx, &m, query, username)
104	return m, err
105}
106
107// GetAllUsers implements store.UserStore.
108func (*userStore) GetAllUsers(ctx context.Context, tx *db.Tx) ([]models.User, error) {
109	var ms []models.User
110	query := tx.Rebind(`SELECT * FROM users;`)
111	err := tx.SelectContext(ctx, &ms, query)
112	return ms, err
113}
114
115// ListPublicKeysByUserID implements store.UserStore..
116func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx *db.Tx, id int64) ([]ssh.PublicKey, error) {
117	var aks []string
118	query := tx.Rebind(`SELECT public_key FROM public_keys
119			WHERE user_id = ?
120			ORDER BY public_keys.id ASC;`)
121	err := tx.SelectContext(ctx, &aks, query, id)
122	if err != nil {
123		return nil, err
124	}
125
126	pks := make([]ssh.PublicKey, len(aks))
127	for i, ak := range aks {
128		pk, _, err := sshutils.ParseAuthorizedKey(ak)
129		if err != nil {
130			return nil, err
131		}
132		pks[i] = pk
133	}
134
135	return pks, nil
136}
137
138// ListPublicKeysByUsername implements store.UserStore.
139func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx *db.Tx, username string) ([]ssh.PublicKey, error) {
140	username = strings.ToLower(username)
141	if err := utils.ValidateUsername(username); err != nil {
142		return nil, err
143	}
144
145	var aks []string
146	query := tx.Rebind(`SELECT public_key FROM public_keys
147			INNER JOIN users ON users.id = public_keys.user_id
148			WHERE users.username = ?
149			ORDER BY public_keys.id ASC;`)
150	err := tx.SelectContext(ctx, &aks, query, username)
151	if err != nil {
152		return nil, err
153	}
154
155	pks := make([]ssh.PublicKey, len(aks))
156	for i, ak := range aks {
157		pk, _, err := sshutils.ParseAuthorizedKey(ak)
158		if err != nil {
159			return nil, err
160		}
161		pks[i] = pk
162	}
163
164	return pks, nil
165}
166
167// RemovePublicKeyByUsername implements store.UserStore.
168func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx *db.Tx, username string, pk ssh.PublicKey) error {
169	username = strings.ToLower(username)
170	if err := utils.ValidateUsername(username); err != nil {
171		return err
172	}
173
174	query := tx.Rebind(`DELETE FROM public_keys
175			WHERE user_id = (SELECT id FROM users WHERE username = ?)
176			AND public_key = ?;`)
177	_, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
178	return err
179}
180
181// SetAdminByUsername implements store.UserStore.
182func (*userStore) SetAdminByUsername(ctx context.Context, tx *db.Tx, username string, isAdmin bool) error {
183	username = strings.ToLower(username)
184	if err := utils.ValidateUsername(username); err != nil {
185		return err
186	}
187
188	query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)
189	_, err := tx.ExecContext(ctx, query, isAdmin, username)
190	return err
191}
192
193// SetUsernameByUsername implements store.UserStore.
194func (*userStore) SetUsernameByUsername(ctx context.Context, tx *db.Tx, username string, newUsername string) error {
195	username = strings.ToLower(username)
196	if err := utils.ValidateUsername(username); err != nil {
197		return err
198	}
199
200	newUsername = strings.ToLower(newUsername)
201	if err := utils.ValidateUsername(newUsername); err != nil {
202		return err
203	}
204
205	query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)
206	_, err := tx.ExecContext(ctx, query, newUsername, username)
207	return err
208}