1// SPDX-FileCopyrightText: Amolith <amolith@secluded.site>
2//
3// SPDX-License-Identifier: Apache-2.0
4
5package db
6
7import (
8 "database/sql"
9 "fmt"
10 "time"
11)
12
13// DeleteUser deletes specific user from the database and returns an error if it
14// fails.
15func DeleteUser(db *sql.DB, user string) error {
16 mutex.Lock()
17 defer mutex.Unlock()
18
19 _, err := db.Exec("DELETE FROM users WHERE username = ?", user)
20 if err != nil {
21 return fmt.Errorf("failed to execute SQL: %w", err)
22 }
23
24 return nil
25}
26
27// CreateUser creates a new user in the database and returns an error if it fails.
28func CreateUser(db *sql.DB, username, hash, salt string) error {
29 mutex.Lock()
30 defer mutex.Unlock()
31
32 _, err := db.Exec("INSERT INTO users (username, hash, salt) VALUES (?, ?, ?)", username, hash, salt)
33 if err != nil {
34 return fmt.Errorf("failed to execute SQL: %w", err)
35 }
36
37 return nil
38}
39
// GetUser looks up username in the users table and returns its stored hash and
// salt, in that order.
//
// If the query fails, or no row matches, both strings are empty and a wrapped
// error is returned. Because the scan error is wrapped with %w, a missing user
// can be detected with errors.Is(err, sql.ErrNoRows).
func GetUser(db *sql.DB, username string) (string, string, error) {
	row := db.QueryRow("SELECT hash, salt FROM users WHERE username = ?", username)

	var storedHash, storedSalt string
	if err := row.Scan(&storedHash, &storedSalt); err != nil {
		return "", "", fmt.Errorf("failed to scan row: %w", err)
	}
	return storedHash, storedSalt, nil
}
52
// GetUsers lists every username in the users table, in the order the database
// yields them.
//
// When the table is empty the result is a nil slice with a nil error. If the
// query, any row scan, or row iteration fails, the names collected so far are
// discarded and a wrapped error is returned instead.
func GetUsers(db *sql.DB) ([]string, error) {
	rows, err := db.Query("SELECT username FROM users")
	if err != nil {
		return nil, fmt.Errorf("failed to query database: %w", err)
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, fmt.Errorf("failed to scan row: %w", scanErr)
		}
		names = append(names, name)
	}

	// rows.Next returning false can mean either exhaustion or a failure;
	// rows.Err distinguishes the two.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("failed to iterate rows: %w", iterErr)
	}
	return names, nil
}
80
// GetSession looks up the session whose token equals session and returns the
// username it belongs to together with the session's expiry time.
//
// The expires column is read as text and parsed as RFC 3339, which is the
// format CreateSession and InvalidateSession write. GetSession does NOT compare
// the expiry against the current time. Callers must check whether the returned
// time has already passed before trusting the session.
//
// On failure the zero values ("" and time.Time{}) are returned with a wrapped
// error. An unknown token yields an error matching errors.Is(err,
// sql.ErrNoRows). An unparseable expiry yields an error matching
// errors.As(err, **time.ParseError).
func GetSession(db *sql.DB, session string) (string, time.Time, error) {
	var username, expiresText string

	err := db.QueryRow("SELECT username, expires FROM sessions WHERE token = ?", session).
		Scan(&username, &expiresText)
	if err != nil {
		return "", time.Time{}, fmt.Errorf("failed to scan row: %w", err)
	}

	// Don't leak the username when the row is unusable: the caller gets
	// either a fully valid (username, expiry) pair or nothing.
	expires, err := time.Parse(time.RFC3339, expiresText)
	if err != nil {
		return "", time.Time{}, fmt.Errorf("failed to parse time: %w", err)
	}

	return username, expires, nil
}
101
102// InvalidateSession invalidates a session by setting the expiration date to the
103// provided time.
104func InvalidateSession(db *sql.DB, session string, expiry time.Time) error {
105 mutex.Lock()
106 defer mutex.Unlock()
107
108 _, err := db.Exec("UPDATE sessions SET expires = ? WHERE token = ?", expiry.Format(time.RFC3339), session)
109 if err != nil {
110 return fmt.Errorf("failed to execute SQL: %w", err)
111 }
112
113 return nil
114}
115
116// CreateSession creates a new session in the database and returns an error if
117// it fails.
118func CreateSession(db *sql.DB, username, token string, expiry time.Time) error {
119 mutex.Lock()
120 defer mutex.Unlock()
121
122 _, err := db.Exec(
123 "INSERT INTO sessions (token, username, expires) VALUES (?, ?, ?)",
124 token,
125 username,
126 expiry.Format(time.RFC3339),
127 )
128 if err != nil {
129 return fmt.Errorf("failed to execute SQL: %w", err)
130 }
131
132 return nil
133}