1package database
  2
  3import (
  4	"context"
  5	"strings"
  6
  7	"github.com/charmbracelet/soft-serve/server/db"
  8	"github.com/charmbracelet/soft-serve/server/db/models"
  9	"github.com/charmbracelet/soft-serve/server/sshutils"
 10	"github.com/charmbracelet/soft-serve/server/store"
 11	"github.com/charmbracelet/soft-serve/server/utils"
 12	"golang.org/x/crypto/ssh"
 13)
 14
 15type userStore struct{}
 16
 17var _ store.UserStore = (*userStore)(nil)
 18
 19// AddPublicKeyByUsername implements store.UserStore.
 20func (*userStore) AddPublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
 21	username = strings.ToLower(username)
 22	if err := utils.ValidateUsername(username); err != nil {
 23		return err
 24	}
 25
 26	var userID int64
 27	if err := tx.GetContext(ctx, &userID, tx.Rebind(`SELECT id FROM users WHERE username = ?`), username); err != nil {
 28		return err
 29	}
 30
 31	query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
 32			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 33	ak := sshutils.MarshalAuthorizedKey(pk)
 34	_, err := tx.ExecContext(ctx, query, userID, ak)
 35
 36	return err
 37}
 38
 39// CreateUser implements store.UserStore.
 40func (*userStore) CreateUser(ctx context.Context, tx db.Handler, username string, isAdmin bool, pks []ssh.PublicKey) error {
 41	username = strings.ToLower(username)
 42	if err := utils.ValidateUsername(username); err != nil {
 43		return err
 44	}
 45
 46	query := tx.Rebind(`INSERT INTO users (username, admin, updated_at)
 47			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 48	result, err := tx.ExecContext(ctx, query, username, isAdmin)
 49	if err != nil {
 50		return err
 51	}
 52
 53	userID, err := result.LastInsertId()
 54	if err != nil {
 55		return err
 56	}
 57
 58	for _, pk := range pks {
 59		query := tx.Rebind(`INSERT INTO public_keys (user_id, public_key, updated_at)
 60			VALUES (?, ?, CURRENT_TIMESTAMP);`)
 61		ak := sshutils.MarshalAuthorizedKey(pk)
 62		_, err := tx.ExecContext(ctx, query, userID, ak)
 63		if err != nil {
 64			return err
 65		}
 66	}
 67
 68	return nil
 69}
 70
 71// DeleteUserByUsername implements store.UserStore.
 72func (*userStore) DeleteUserByUsername(ctx context.Context, tx db.Handler, username string) error {
 73	username = strings.ToLower(username)
 74	if err := utils.ValidateUsername(username); err != nil {
 75		return err
 76	}
 77
 78	query := tx.Rebind(`DELETE FROM users WHERE username = ?;`)
 79	_, err := tx.ExecContext(ctx, query, username)
 80	return err
 81}
 82
 83// GetUserByID implements store.UserStore.
 84func (*userStore) GetUserByID(ctx context.Context, tx db.Handler, id int64) (models.User, error) {
 85	var m models.User
 86	query := tx.Rebind(`SELECT * FROM users WHERE id = ?;`)
 87	err := tx.GetContext(ctx, &m, query, id)
 88	return m, err
 89}
 90
 91// FindUserByPublicKey implements store.UserStore.
 92func (*userStore) FindUserByPublicKey(ctx context.Context, tx db.Handler, pk ssh.PublicKey) (models.User, error) {
 93	var m models.User
 94	query := tx.Rebind(`SELECT users.*
 95			FROM users
 96			INNER JOIN public_keys ON users.id = public_keys.user_id
 97			WHERE public_keys.public_key = ?;`)
 98	err := tx.GetContext(ctx, &m, query, sshutils.MarshalAuthorizedKey(pk))
 99	return m, err
100}
101
102// FindUserByUsername implements store.UserStore.
103func (*userStore) FindUserByUsername(ctx context.Context, tx db.Handler, username string) (models.User, error) {
104	username = strings.ToLower(username)
105	if err := utils.ValidateUsername(username); err != nil {
106		return models.User{}, err
107	}
108
109	var m models.User
110	query := tx.Rebind(`SELECT * FROM users WHERE username = ?;`)
111	err := tx.GetContext(ctx, &m, query, username)
112	return m, err
113}
114
115// FindUserByAccessToken implements store.UserStore.
116func (*userStore) FindUserByAccessToken(ctx context.Context, tx db.Handler, token string) (models.User, error) {
117	var m models.User
118	query := tx.Rebind(`SELECT users.*
119			FROM users
120			INNER JOIN access_tokens ON users.id = access_tokens.user_id
121			WHERE access_tokens.token = ?;`)
122	err := tx.GetContext(ctx, &m, query, token)
123	return m, err
124}
125
126// GetAllUsers implements store.UserStore.
127func (*userStore) GetAllUsers(ctx context.Context, tx db.Handler) ([]models.User, error) {
128	var ms []models.User
129	query := tx.Rebind(`SELECT * FROM users;`)
130	err := tx.SelectContext(ctx, &ms, query)
131	return ms, err
132}
133
134// ListPublicKeysByUserID implements store.UserStore..
135func (*userStore) ListPublicKeysByUserID(ctx context.Context, tx db.Handler, id int64) ([]ssh.PublicKey, error) {
136	var aks []string
137	query := tx.Rebind(`SELECT public_key FROM public_keys
138			WHERE user_id = ?
139			ORDER BY public_keys.id ASC;`)
140	err := tx.SelectContext(ctx, &aks, query, id)
141	if err != nil {
142		return nil, err
143	}
144
145	pks := make([]ssh.PublicKey, len(aks))
146	for i, ak := range aks {
147		pk, _, err := sshutils.ParseAuthorizedKey(ak)
148		if err != nil {
149			return nil, err
150		}
151		pks[i] = pk
152	}
153
154	return pks, nil
155}
156
157// ListPublicKeysByUsername implements store.UserStore.
158func (*userStore) ListPublicKeysByUsername(ctx context.Context, tx db.Handler, username string) ([]ssh.PublicKey, error) {
159	username = strings.ToLower(username)
160	if err := utils.ValidateUsername(username); err != nil {
161		return nil, err
162	}
163
164	var aks []string
165	query := tx.Rebind(`SELECT public_key FROM public_keys
166			INNER JOIN users ON users.id = public_keys.user_id
167			WHERE users.username = ?
168			ORDER BY public_keys.id ASC;`)
169	err := tx.SelectContext(ctx, &aks, query, username)
170	if err != nil {
171		return nil, err
172	}
173
174	pks := make([]ssh.PublicKey, len(aks))
175	for i, ak := range aks {
176		pk, _, err := sshutils.ParseAuthorizedKey(ak)
177		if err != nil {
178			return nil, err
179		}
180		pks[i] = pk
181	}
182
183	return pks, nil
184}
185
186// RemovePublicKeyByUsername implements store.UserStore.
187func (*userStore) RemovePublicKeyByUsername(ctx context.Context, tx db.Handler, username string, pk ssh.PublicKey) error {
188	username = strings.ToLower(username)
189	if err := utils.ValidateUsername(username); err != nil {
190		return err
191	}
192
193	query := tx.Rebind(`DELETE FROM public_keys
194			WHERE user_id = (SELECT id FROM users WHERE username = ?)
195			AND public_key = ?;`)
196	_, err := tx.ExecContext(ctx, query, username, sshutils.MarshalAuthorizedKey(pk))
197	return err
198}
199
200// SetAdminByUsername implements store.UserStore.
201func (*userStore) SetAdminByUsername(ctx context.Context, tx db.Handler, username string, isAdmin bool) error {
202	username = strings.ToLower(username)
203	if err := utils.ValidateUsername(username); err != nil {
204		return err
205	}
206
207	query := tx.Rebind(`UPDATE users SET admin = ? WHERE username = ?;`)
208	_, err := tx.ExecContext(ctx, query, isAdmin, username)
209	return err
210}
211
212// SetUsernameByUsername implements store.UserStore.
213func (*userStore) SetUsernameByUsername(ctx context.Context, tx db.Handler, username string, newUsername string) error {
214	username = strings.ToLower(username)
215	if err := utils.ValidateUsername(username); err != nil {
216		return err
217	}
218
219	newUsername = strings.ToLower(newUsername)
220	if err := utils.ValidateUsername(newUsername); err != nil {
221		return err
222	}
223
224	query := tx.Rebind(`UPDATE users SET username = ? WHERE username = ?;`)
225	_, err := tx.ExecContext(ctx, query, newUsername, username)
226	return err
227}
228
229// SetUserPassword implements store.UserStore.
230func (*userStore) SetUserPassword(ctx context.Context, tx db.Handler, userID int64, password string) error {
231	query := tx.Rebind(`UPDATE users SET password = ? WHERE id = ?;`)
232	_, err := tx.ExecContext(ctx, query, password, userID)
233	return err
234}
235
236// SetUserPasswordByUsername implements store.UserStore.
237func (*userStore) SetUserPasswordByUsername(ctx context.Context, tx db.Handler, username string, password string) error {
238	username = strings.ToLower(username)
239	if err := utils.ValidateUsername(username); err != nil {
240		return err
241	}
242
243	query := tx.Rebind(`UPDATE users SET password = ? WHERE username = ?;`)
244	_, err := tx.ExecContext(ctx, query, password, username)
245	return err
246}