users.go

  1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5package users
  6
import (
	"crypto/rand"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"fmt"
	"time"

	"git.sr.ht/~amolith/willow/db"
	"golang.org/x/crypto/argon2"
)
 17
 18const (
 19	argon2Time    = 2
 20	saltLength    = 16
 21	argon2KeyLen  = 64
 22	argon2Memory  = 64 * 1024
 23	argon2Threads = 4
 24)
 25
 26// argonHash accepts two strings for the user's password and a random salt,
 27// hashes the password using the salt, and returns the hash as a base64-encoded
 28// string.
 29func argonHash(password, salt string) (string, error) {
 30	decodedSalt, err := base64.StdEncoding.DecodeString(salt)
 31	if err != nil {
 32		return "", fmt.Errorf("failed to decode base64: %w", err)
 33	}
 34
 35	return base64.StdEncoding.EncodeToString(argon2.IDKey(
 36		[]byte(password),
 37		decodedSalt,
 38		argon2Time,
 39		argon2Memory,
 40		argon2Threads,
 41		argon2KeyLen,
 42	)), nil
 43}
 44
 45// generateSalt generates a random salt and returns it as a base64-encoded
 46// string.
 47func generateSalt() (string, error) {
 48	salt := make([]byte, saltLength)
 49
 50	_, err := rand.Read(salt)
 51	if err != nil {
 52		return "", fmt.Errorf("failed to generate random bytes: %w", err)
 53	}
 54
 55	return base64.StdEncoding.EncodeToString(salt), nil
 56}
 57
 58// Register accepts a username and password, hashes the password and stores the
 59// hash and salt in the database.
 60func Register(dbConn *sql.DB, username, password string) error {
 61	salt, err := generateSalt()
 62	if err != nil {
 63		return err
 64	}
 65
 66	hash, err := argonHash(password, salt)
 67	if err != nil {
 68		return err
 69	}
 70
 71	err = db.CreateUser(dbConn, username, hash, salt)
 72	if err != nil {
 73		return fmt.Errorf("failed to create user: %w", err)
 74	}
 75
 76	return nil
 77}
 78
 79// Delete removes a user from the database.
 80func Delete(dbConn *sql.DB, username string) error {
 81	err := db.DeleteUser(dbConn, username)
 82	if err != nil {
 83		return fmt.Errorf("failed to delete user: %w", err)
 84	}
 85
 86	return nil
 87}
 88
 89// UserAuthorised accepts a username string, a token string, and returns true if the
 90// user is authorised, false if not, and an error if one is encountered.
 91func UserAuthorised(dbConn *sql.DB, username, token string) (bool, error) {
 92	dbHash, dbSalt, err := db.GetUser(dbConn, username)
 93	if err != nil {
 94		return false, fmt.Errorf("failed to get user: %w", err)
 95	}
 96
 97	providedHash, err := argonHash(token, dbSalt)
 98	if err != nil {
 99		return false, err
100	}
101
102	return dbHash == providedHash, nil
103}
104
105// SessionAuthorised accepts a session string and returns true if the session is
106// valid and false if not.
107func SessionAuthorised(dbConn *sql.DB, session string) (bool, error) {
108	dbResult, expiry, err := db.GetSession(dbConn, session)
109	if err != nil {
110		return false, fmt.Errorf("failed to get session: %w", err)
111	}
112
113	if dbResult == "" || expiry.Before(time.Now()) {
114		return false, nil
115	}
116
117	return true, nil
118}
119
120// InvalidateSession invalidates a session by setting the expiration date to now.
121func InvalidateSession(dbConn *sql.DB, session string) error {
122	err := db.InvalidateSession(dbConn, session, time.Now())
123	if err != nil {
124		return fmt.Errorf("failed to invalidate session: %w", err)
125	}
126
127	return nil
128}
129
130// CreateSession accepts a username, generates a token, stores it in the
131// database, and returns it.
132func CreateSession(dbConn *sql.DB, username string) (string, time.Time, error) {
133	token, err := generateSalt()
134	if err != nil {
135		return "", time.Time{}, err
136	}
137
138	expiry := time.Now().Add(7 * 24 * time.Hour)
139
140	err = db.CreateSession(dbConn, username, token, expiry)
141	if err != nil {
142		return "", time.Time{}, fmt.Errorf("failed to create session: %w", err)
143	}
144
145	return token, expiry, nil
146}
147
148// GetUsers returns a list of all users in the database as a slice of strings.
149func GetUsers(dbConn *sql.DB) ([]string, error) {
150	users, err := db.GetUsers(dbConn)
151	if err != nil {
152		return nil, fmt.Errorf("failed to get users: %w", err)
153	}
154
155	return users, nil
156}