1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5package users
  6
import (
	"crypto/rand"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"fmt"
	"time"

	"git.sr.ht/~amolith/willow/db"
	"golang.org/x/crypto/argon2"
)
 17
 18const (
 19	argon2Time    = 2
 20	saltLength    = 16
 21	argon2KeyLen  = 64
 22	argon2Memory  = 64 * 1024
 23	argon2Threads = 4
 24)
 25
 26// argonHash accepts two strings for the user's password and a random salt,
 27// hashes the password using the salt, and returns the hash as a base64-encoded
 28// string.
 29func argonHash(password, salt string) (string, error) {
 30	decodedSalt, err := base64.StdEncoding.DecodeString(salt)
 31	if err != nil {
 32		return "", fmt.Errorf("failed to decode base64: %w", err)
 33	}
 34
 35	return base64.StdEncoding.EncodeToString(argon2.IDKey([]byte(password), decodedSalt, argon2Time, argon2Memory, argon2Threads, argon2KeyLen)), nil
 36}
 37
 38// generateSalt generates a random salt and returns it as a base64-encoded
 39// string.
 40func generateSalt() (string, error) {
 41	salt := make([]byte, saltLength)
 42
 43	_, err := rand.Read(salt)
 44	if err != nil {
 45		return "", fmt.Errorf("failed to generate random bytes: %w", err)
 46	}
 47
 48	return base64.StdEncoding.EncodeToString(salt), nil
 49}
 50
 51// Register accepts a username and password, hashes the password and stores the
 52// hash and salt in the database.
 53func Register(dbConn *sql.DB, username, password string) error {
 54	salt, err := generateSalt()
 55	if err != nil {
 56		return err
 57	}
 58
 59	hash, err := argonHash(password, salt)
 60	if err != nil {
 61		return err
 62	}
 63
 64	err = db.CreateUser(dbConn, username, hash, salt)
 65	if err != nil {
 66		return fmt.Errorf("failed to create user: %w", err)
 67	}
 68
 69	return nil
 70}
 71
 72// Delete removes a user from the database.
 73func Delete(dbConn *sql.DB, username string) error {
 74	err := db.DeleteUser(dbConn, username)
 75	if err != nil {
 76		return fmt.Errorf("failed to delete user: %w", err)
 77	}
 78
 79	return nil
 80}
 81
 82// UserAuthorised accepts a username string, a token string, and returns true if the
 83// user is authorised, false if not, and an error if one is encountered.
 84func UserAuthorised(dbConn *sql.DB, username, token string) (bool, error) {
 85	dbHash, dbSalt, err := db.GetUser(dbConn, username)
 86	if err != nil {
 87		return false, fmt.Errorf("failed to get user: %w", err)
 88	}
 89
 90	providedHash, err := argonHash(token, dbSalt)
 91	if err != nil {
 92		return false, err
 93	}
 94
 95	return dbHash == providedHash, nil
 96}
 97
 98// SessionAuthorised accepts a session string and returns true if the session is
 99// valid and false if not.
100func SessionAuthorised(dbConn *sql.DB, session string) (bool, error) {
101	dbResult, expiry, err := db.GetSession(dbConn, session)
102	if err != nil {
103		return false, fmt.Errorf("failed to get session: %w", err)
104	}
105
106	if dbResult == "" || expiry.Before(time.Now()) {
107		return false, nil
108	}
109
110	return true, nil
111}
112
// InvalidateSession invalidates session by setting its expiry to the current
// time through db.InvalidateSession. Any error is wrapped and returned.
func InvalidateSession(dbConn *sql.DB, session string) error {
	now := time.Now()
	if invErr := db.InvalidateSession(dbConn, session, now); invErr != nil {
		return fmt.Errorf("failed to invalidate session: %w", invErr)
	}
	return nil
}
122
123// CreateSession accepts a username, generates a token, stores it in the
124// database, and returns it.
125func CreateSession(dbConn *sql.DB, username string) (string, time.Time, error) {
126	token, err := generateSalt()
127	if err != nil {
128		return "", time.Time{}, err
129	}
130
131	expiry := time.Now().Add(7 * 24 * time.Hour)
132
133	err = db.CreateSession(dbConn, username, token, expiry)
134	if err != nil {
135		return "", time.Time{}, fmt.Errorf("failed to create session: %w", err)
136	}
137
138	return token, expiry, nil
139}
140
// GetUsers returns all users stored in the database as a slice of strings,
// as reported by db.GetUsers. On failure, it returns a nil slice and a
// wrapped error.
func GetUsers(dbConn *sql.DB) ([]string, error) {
	names, listErr := db.GetUsers(dbConn)
	if listErr != nil {
		return nil, fmt.Errorf("failed to get users: %w", listErr)
	}
	return names, nil
}