users.go

  1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
  2//
  3// SPDX-License-Identifier: Apache-2.0
  4
  5package db
  6
  7import (
  8	"database/sql"
  9	"fmt"
 10	"time"
 11)
 12
 13// DeleteUser deletes specific user from the database and returns an error if it
 14// fails.
 15func DeleteUser(db *sql.DB, user string) error {
 16	mutex.Lock()
 17	defer mutex.Unlock()
 18
 19	_, err := db.Exec("DELETE FROM users WHERE username = ?", user)
 20	if err != nil {
 21		return fmt.Errorf("failed to execute SQL: %w", err)
 22	}
 23
 24	return nil
 25}
 26
 27// CreateUser creates a new user in the database and returns an error if it fails.
 28func CreateUser(db *sql.DB, username, hash, salt string) error {
 29	mutex.Lock()
 30	defer mutex.Unlock()
 31
 32	_, err := db.Exec("INSERT INTO users (username, hash, salt) VALUES (?, ?, ?)", username, hash, salt)
 33	if err != nil {
 34		return fmt.Errorf("failed to execute SQL: %w", err)
 35	}
 36
 37	return nil
 38}
 39
 40// GetUser returns a user's hash and salt from the database as strings and
 41// returns an error if it fails.
 42func GetUser(db *sql.DB, username string) (string, string, error) {
 43	var hash, salt string
 44
 45	err := db.QueryRow("SELECT hash, salt FROM users WHERE username = ?", username).Scan(&hash, &salt)
 46	if err != nil {
 47		return "", "", fmt.Errorf("failed to scan row: %w", err)
 48	}
 49
 50	return hash, salt, nil
 51}
 52
 53// GetUsers returns a list of all users in the database as a slice of strings
 54// and returns an error if it fails.
 55func GetUsers(db *sql.DB) ([]string, error) {
 56	rows, err := db.Query("SELECT username FROM users")
 57	if err != nil {
 58		return nil, fmt.Errorf("failed to query database: %w", err)
 59	}
 60	defer rows.Close()
 61
 62	var users []string
 63	for rows.Next() {
 64		var user string
 65
 66		err = rows.Scan(&user)
 67		if err != nil {
 68			return nil, fmt.Errorf("failed to scan row: %w", err)
 69		}
 70
 71		users = append(users, user)
 72	}
 73
 74	if err = rows.Err(); err != nil {
 75		return nil, fmt.Errorf("failed to iterate rows: %w", err)
 76	}
 77
 78	return users, nil
 79}
 80
 81// GetSession accepts a session ID and returns the username associated with it
 82// and an error.
 83func GetSession(db *sql.DB, session string) (string, time.Time, error) {
 84	var (
 85		username      string
 86		expiresString string
 87	)
 88
 89	err := db.QueryRow("SELECT username, expires FROM sessions WHERE token = ?", session).Scan(&username, &expiresString)
 90	if err != nil {
 91		return "", time.Time{}, fmt.Errorf("failed to scan row: %w", err)
 92	}
 93
 94	expires, err := time.Parse(time.RFC3339, expiresString)
 95	if err != nil {
 96		return "", time.Time{}, fmt.Errorf("failed to parse time: %w", err)
 97	}
 98
 99	return username, expires, nil
100}
101
102// InvalidateSession invalidates a session by setting the expiration date to the
103// provided time.
104func InvalidateSession(db *sql.DB, session string, expiry time.Time) error {
105	mutex.Lock()
106	defer mutex.Unlock()
107
108	_, err := db.Exec("UPDATE sessions SET expires = ? WHERE token = ?", expiry.Format(time.RFC3339), session)
109	if err != nil {
110		return fmt.Errorf("failed to execute SQL: %w", err)
111	}
112
113	return nil
114}
115
116// CreateSession creates a new session in the database and returns an error if
117// it fails.
118func CreateSession(db *sql.DB, username, token string, expiry time.Time) error {
119	mutex.Lock()
120	defer mutex.Unlock()
121
122	_, err := db.Exec("INSERT INTO sessions (token, username, expires) VALUES (?, ?, ?)", token, username, expiry.Format(time.RFC3339))
123	if err != nil {
124		return fmt.Errorf("failed to execute SQL: %w", err)
125	}
126
127	return nil
128}